Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,64 +1,129 @@
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
-
from
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
|
|
8 |
|
|
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
message,
|
12 |
-
history: list[tuple[str, str]],
|
13 |
-
system_message,
|
14 |
-
max_tokens,
|
15 |
-
temperature,
|
16 |
-
top_p,
|
17 |
-
):
|
18 |
-
messages = [{"role": "system", "content": system_message}]
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
25 |
|
26 |
-
|
|
|
27 |
|
28 |
-
|
|
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
)
|
37 |
-
|
|
|
38 |
|
39 |
-
|
40 |
-
|
|
|
|
|
41 |
|
|
|
|
|
|
|
|
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
61 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
if __name__ == "__main__":
|
64 |
-
|
|
|
|
1 |
+
import os
|
2 |
+
import subprocess
|
3 |
import gradio as gr
|
4 |
+
from transformers import AutoTokenizer
|
5 |
+
from optimum.intel.openvino import OVModelForCausalLM
|
6 |
+
from sentence_transformers import SentenceTransformer
|
7 |
+
import faiss
|
8 |
+
import numpy as np
|
9 |
+
import warnings
|
10 |
|
11 |
+
warnings.filterwarnings(
|
12 |
+
"ignore",
|
13 |
+
category=DeprecationWarning,
|
14 |
+
message="__array__ implementation doesn't accept a copy keyword"
|
15 |
+
)
|
16 |
|
17 |
+
# 設定模型 ID 與轉換後存檔路徑(8-bit 量化版)
|
18 |
+
model_id = "agentica-org/DeepScaleR-1.5B-Preview"
|
19 |
+
export_path = "exported_model_openvino_int8"
|
20 |
|
21 |
+
print("Loading model as OpenVINO int8 (8-bit) model...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
+
if os.path.exists(export_path) and os.listdir(export_path):
|
24 |
+
print(f"Found quantized OpenVINO model at '{export_path}', loading it...")
|
25 |
+
model = OVModelForCausalLM.from_pretrained(export_path, device_map="auto", use_cache=False)
|
26 |
+
else:
|
27 |
+
print("No quantized model found, exporting and quantizing to OpenVINO int8 now...")
|
28 |
+
# 透過 optimum-cli 導出並量化模型(此命令行參數根據你的任務可能需要調整)
|
29 |
+
command = [
|
30 |
+
"optimum-cli", "export", "openvino",
|
31 |
+
"--model", model_id,
|
32 |
+
"--task", "text-generation",
|
33 |
+
"--weight-format", "int8",
|
34 |
+
export_path
|
35 |
+
]
|
36 |
+
subprocess.run(command, check=True)
|
37 |
+
print(f"Quantized model saved to '{export_path}'.")
|
38 |
+
model = OVModelForCausalLM.from_pretrained(export_path, device_map="auto", use_cache=False)
|
39 |
|
40 |
+
print("Loading tokenizer...")
|
41 |
+
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
|
42 |
|
43 |
+
# 載入向量模型(用於將文本轉換為向量)
|
44 |
+
encoder = SentenceTransformer("all-MiniLM-L6-v2")
|
45 |
|
46 |
+
# FAQ 知識庫(問題 + 回答)
|
47 |
+
faq_data = [
|
48 |
+
("What is FAISS?", "FAISS is a library for efficient similarity search and clustering of dense vectors."),
|
49 |
+
("How does FAISS work?", "FAISS uses indexing structures to quickly retrieve the nearest neighbors of a query vector."),
|
50 |
+
("Can FAISS run on GPU?", "Yes, FAISS supports GPU acceleration for faster computation."),
|
51 |
+
("What is OpenVINO?", "OpenVINO is an inference engine optimized for Intel hardware."),
|
52 |
+
("How to fine-tune a model?", "Fine-tuning involves training a model on a specific dataset to adapt it to a particular task."),
|
53 |
+
("What is the best way to optimize inference speed?", "Using quantization and model distillation can significantly improve inference speed.")
|
54 |
+
]
|
55 |
|
56 |
+
# 將 FAQ 問題轉換為向量
|
57 |
+
faq_questions = [q for q, _ in faq_data]
|
58 |
+
faq_answers = [a for _, a in faq_data]
|
59 |
+
faq_vectors = np.array(encoder.encode(faq_questions)).astype("float32")
|
60 |
|
61 |
+
# 建立 FAISS 索引(使用 L2 距離)
|
62 |
+
d = faq_vectors.shape[1] # 向量維度
|
63 |
+
index = faiss.IndexFlatL2(d)
|
64 |
+
index.add(faq_vectors)
|
65 |
|
66 |
+
# 對話歷史記錄
|
67 |
+
history = []
|
68 |
+
|
69 |
+
# 查詢函數:先嘗試從 FAQ 中檢索答案,若無匹配則使用 OpenVINO 模型生成回答
|
70 |
+
def respond(prompt):
|
71 |
+
global history
|
72 |
+
# 將輸入轉換為向量,並使用 FAISS 查詢最相近的 FAQ 問題
|
73 |
+
query_vector = np.array(encoder.encode([prompt])).astype("float32")
|
74 |
+
D, I = index.search(query_vector, 1)
|
75 |
+
|
76 |
+
if D[0][0] < 1.0:
|
77 |
+
response = faq_answers[I[0][0]]
|
78 |
+
else:
|
79 |
+
# 若 FAQ 無匹配,則使用 OpenVINO 模型生成回答
|
80 |
+
messages = [{"role": "system", "content": "Answer the question in English only."}]
|
81 |
+
for user_text, assistant_text in history:
|
82 |
+
messages.append({"role": "user", "content": user_text})
|
83 |
+
messages.append({"role": "assistant", "content": assistant_text})
|
84 |
+
messages.append({"role": "user", "content": prompt})
|
85 |
+
|
86 |
+
# 將對話訊息組成一個 prompt(以換行分隔)
|
87 |
+
chat_prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])
|
88 |
+
model_inputs = tokenizer(chat_prompt, return_tensors="pt").to(model.device)
|
89 |
+
generated_ids = model.generate(
|
90 |
+
**model_inputs,
|
91 |
+
max_new_tokens=512,
|
92 |
+
temperature=0.7,
|
93 |
+
top_p=0.9,
|
94 |
+
do_sample=True
|
95 |
+
)
|
96 |
+
response = tokenizer.decode(generated_ids[0], skip_special_tokens=True).strip()
|
97 |
+
|
98 |
+
history.append((prompt, response))
|
99 |
+
return response
|
100 |
+
|
101 |
+
# 清除對話歷史記錄
|
102 |
+
def clear_history():
|
103 |
+
global history
|
104 |
+
history = []
|
105 |
+
return "History cleared!"
|
106 |
|
107 |
+
# 建立 Gradio 介面
|
108 |
+
with gr.Blocks() as demo:
|
109 |
+
gr.Markdown("# DeepScaleR-1.5B-Preview (OpenVINO int8) Chatbot with FAISS FAQ")
|
110 |
+
|
111 |
+
with gr.Tabs():
|
112 |
+
with gr.TabItem("Chat"):
|
113 |
+
chat_interface = gr.Interface(
|
114 |
+
fn=respond,
|
115 |
+
inputs=gr.Textbox(label="Prompt", placeholder="Enter your message..."),
|
116 |
+
outputs=gr.Textbox(label="Response", interactive=False),
|
117 |
+
api_name="hchat",
|
118 |
+
title="DeepScaleR-1.5B-Preview (OpenVINO int8) Chatbot",
|
119 |
+
description="This chatbot first searches an FAQ database using FAISS, then uses an OpenVINO 8-bit model to generate a response if no FAQ match is found."
|
120 |
+
)
|
121 |
+
|
122 |
+
with gr.Row():
|
123 |
+
clear_button = gr.Button("🧹 Clear History")
|
124 |
+
|
125 |
+
clear_button.click(fn=clear_history, inputs=[], outputs=[])
|
126 |
|
127 |
if __name__ == "__main__":
|
128 |
+
print("Launching Gradio app...")
|
129 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
|