Spaces:
Sleeping
Sleeping
File size: 6,216 Bytes
09830df d5e0e3a 09830df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
import duckdb
import pandas as pd
from tabulate import tabulate
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from sentence_transformers import SentenceTransformer, util
import threading
# --------------------------
# Setup: Load data and models
# --------------------------
df = pd.read_excel("mau_bao_cao.xlsx")
conn = duckdb.connect('mau_bao_cao.db')
conn.execute("""\
CREATE TABLE IF NOT EXISTS production_data AS
SELECT * FROM read_xlsx('mau_bao_cao.xlsx');
""")
# Load embedding model for computing embeddings.
embedding_model = SentenceTransformer("intfloat/multilingual-e5-large-instruct")
column_names = df.columns.tolist()
column_embeddings = embedding_model.encode(column_names, convert_to_tensor=True)
row_texts = df.apply(lambda row: " | ".join(row.astype(str)), axis=1)
row_embeddings = embedding_model.encode(row_texts.tolist(), convert_to_tensor=True)
# Load Qwen model and tokenizer for conversational generation.
fc_model = AutoModelForCausalLM.from_pretrained('Qwen/Qwen2.5-3B-Instruct', torch_dtype=torch.float16, device_map="auto")
fc_tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen2.5-3B-Instruct')
# --------------------------
# Define the streaming generation function
# --------------------------
def generate_response(user_query: str):
"""
A generator function that:
1. Embeds the query.
2. Selects top matching columns and rows from the data.
3. Prepares a system prompt with the extracted table.
4. Uses TextIteratorStreamer to stream the generated response.
Yields the partial generated text as it is updated.
"""
# 1. Embed the user query.
question_embedding = embedding_model.encode(user_query, convert_to_tensor=True)
# 2. Find best matching columns (top 10).
k = 3
column_similarities = util.cos_sim(question_embedding, column_embeddings)[0]
best_column_indices = torch.topk(column_similarities, k).indices.tolist()
best_column_names = [column_names[i] for i in best_column_indices]
# 3. Select top matching rows (top 10).
row_similarities = util.cos_sim(question_embedding, row_embeddings).squeeze(0)
m = 10
best_row_indices = torch.topk(row_similarities, m).indices.tolist()
filtered_df = df.iloc[best_row_indices][best_column_names]
# 4. Format the filtered data as a table.
table_output = tabulate(filtered_df, headers=best_column_names, tablefmt="grid")
# 5. Build the system prompt.
system_prompt = f"""\
Bạn là một trợ lý báo cáo sản xuất thông minh, chuyên phân tích và tổng hợp dữ liệu một cách rõ ràng, dễ hiểu.
Dưới đây là dữ liệu bạn cần phân tích:
🔹 Các cột dữ liệu liên quan: {', '.join(best_column_names)}
🔹 Bảng dữ liệu:
{table_output}
📌 Nhiệm vụ của bạn:
Tóm tắt số liệu quan trọng, tránh liệt kê máy móc.
Nhận xét về xu hướng và điểm bất thường.
Nếu có thể, đề xuất giải pháp hoặc hành động tiếp theo.
📊 Cách trả lời:
✔️ Tự nhiên, dễ hiểu, không quá cứng nhắc.
✔️ Không cần nhắc lại bảng dữ liệu, hãy diễn giải nó.
✔️ Trả lời đúng trọng tâm, không dư thừa.
Ví dụ:
🔹 "Hôm nay, sản lượng đạt 95%, cao hơn 5% so với tuần trước."
⚠️ "Dây chuyền A đang giảm hiệu suất, cần theo dõi thêm."
🚀 "Nếu duy trì tốc độ này, sản lượng tháng có thể vượt kế hoạch 10%."
Bạn đã sẵn sàng phân tích và đưa ra báo cáo!
"""
# 6. Create the conversation messages.
messages = [
{'role': 'system', 'content': system_prompt},
{'role': 'user', 'content': user_query}
]
# 7. Prepare the prompt for the model.
response_template = fc_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
response_inputs = fc_tokenizer(response_template, return_tensors="pt").to(fc_model.device)
# 8. Use TextIteratorStreamer to yield tokens as they are generated.
streamer = TextIteratorStreamer(fc_tokenizer, skip_prompt=True, skip_special_tokens=True)
# Start generation in a separate thread so we can yield tokens as they arrive.
thread = threading.Thread(
target=lambda: fc_model.generate(
**response_inputs,
max_new_tokens=512,
temperature=1,
top_p=0.95,
streamer=streamer
)
)
thread.start()
# 9. Yield tokens incrementally.
collected_text = ""
for new_text in streamer:
collected_text += new_text
yield collected_text
# --------------------------
# Build the Gradio conversation interface
# --------------------------
def chat_interface(user_message, history):
"""
A generator function for Gradio that:
- Updates the conversation history with the user message.
- Streams the model's response token-by-token in real time.
The history is maintained as a list of pairs [user_message, bot_response].
"""
# Create a new conversation entry with user message and an empty bot response.
history.append([user_message, ""])
# Yield the initial state.
yield "", history
# Stream tokens from the generate_response generator.
for partial_response in generate_response(user_message):
# Update the latest conversation entry with the partial bot response.
history[-1][1] = partial_response
yield "", history
with gr.Blocks() as demo:
gr.Markdown("## Gradio Chat Interface with Real-Time Streaming")
chatbot = gr.Chatbot()
state = gr.State([]) # maintain conversation history as a list of pairs
with gr.Row():
txt = gr.Textbox(show_label=False, placeholder="Nhập câu hỏi của bạn...", container=False)
send_btn = gr.Button("Gửi")
# Both submit and click trigger the chat_interface generator.
txt.submit(chat_interface, [txt, state], [txt, chatbot], queue=True)
send_btn.click(chat_interface, [txt, state], [txt, chatbot], queue=True)
demo.launch()
|