import os
import subprocess
import gradio as gr
from transformers import AutoTokenizer
from optimum.intel.openvino import OVModelForCausalLM
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import warnings
warnings.filterwarnings(
    "ignore",
    category=DeprecationWarning,
    message="__array__ implementation doesn't accept a copy keyword"
)
# Model ID and the save path for the converted model (8-bit quantized)
model_id = "agentica-org/DeepScaleR-1.5B-Preview"
export_path = "exported_model_openvino_int8"
print("Loading model as OpenVINO int8 (8-bit) model...")
if os.path.exists(export_path) and os.listdir(export_path):
    print(f"Found quantized OpenVINO model at '{export_path}', loading it...")
    model = OVModelForCausalLM.from_pretrained(export_path, device_map="auto", use_cache=False)
else:
    print("No quantized model found, exporting and quantizing to OpenVINO int8 now...")
    # Export and quantize the model via optimum-cli (these CLI arguments may need
    # adjusting for your task)
    command = [
        "optimum-cli", "export", "openvino",
        "--model", model_id,
        "--task", "text-generation",
        "--weight-format", "int8",
        export_path
    ]
    subprocess.run(command, check=True)
    print(f"Quantized model saved to '{export_path}'.")
    model = OVModelForCausalLM.from_pretrained(export_path, device_map="auto", use_cache=False)
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
# Load the sentence-embedding model (used to convert text into vectors)
encoder = SentenceTransformer("all-MiniLM-L6-v2")
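# all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings.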
# FAQ knowledge base (question + answer pairs)
faq_data = [
    ("What is FAISS?", "FAISS is a library for efficient similarity search and clustering of dense vectors."),
    ("How does FAISS work?", "FAISS uses indexing structures to quickly retrieve the nearest neighbors of a query vector."),
    ("Can FAISS run on GPU?", "Yes, FAISS supports GPU acceleration for faster computation."),
    ("What is OpenVINO?", "OpenVINO is an inference engine optimized for Intel hardware."),
    ("How to fine-tune a model?", "Fine-tuning involves training a model on a specific dataset to adapt it to a particular task."),
    ("What is the best way to optimize inference speed?", "Using quantization and model distillation can significantly improve inference speed.")
]
# Encode the FAQ questions into vectors
faq_questions = [q for q, _ in faq_data]
faq_answers = [a for _, a in faq_data]
faq_vectors = np.array(encoder.encode(faq_questions)).astype("float32")
# Build a FAISS index (using L2 distance)
d = faq_vectors.shape[1]  # embedding dimensionality
index = faiss.IndexFlatL2(d)
index.add(faq_vectors)
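# Note: IndexFlatL2 ranks results by squared L2 distance. A common alternative
# (not used here) is cosine similarity: L2-normalize the vectors with
# faiss.normalize_L2() and search a faiss.IndexFlatIP instead.
# Quick sanity check (illustrative): searching with an already-indexed vector
# should return that entry itself at distance ~0, e.g.:
#   _, check_ids = index.search(faq_vectors[:1], 1)
#   assert check_ids[0][0] == 0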
# Conversation history
history = []
# Respond: first try to retrieve an answer from the FAQ; if nothing matches,
# fall back to generating one with the OpenVINO model
def respond(prompt):
    global history
    # Encode the input and use FAISS to find the closest FAQ question
    query_vector = np.array(encoder.encode([prompt])).astype("float32")
    D, I = index.search(query_vector, 1)
    # Accept the FAQ answer only when the distance is below a heuristic threshold
    if D[0][0] < 1.0:
        response = faq_answers[I[0][0]]
    else:
        # No FAQ match: generate an answer with the OpenVINO model
        messages = [{"role": "system", "content": "Answer the question in English only."}]
        for user_text, assistant_text in history:
            messages.append({"role": "user", "content": user_text})
            messages.append({"role": "assistant", "content": assistant_text})
        messages.append({"role": "user", "content": prompt})
        # Join the chat messages into a single prompt (newline-separated)
        chat_prompt = "\n".join(f"{msg['role']}: {msg['content']}" for msg in messages)
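        # Note: if the tokenizer defines a chat template, tokenizer.apply_chat_template
        # (with add_generation_prompt=True) would build the prompt in the format the
        # model was trained on; the plain "role: content" join above is a simple fallback.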
        model_inputs = tokenizer(chat_prompt, return_tensors="pt").to(model.device)
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True
        )
        # Decode only the newly generated tokens (generate() returns prompt + completion)
        new_tokens = generated_ids[0][model_inputs["input_ids"].shape[1]:]
        response = tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
    history.append((prompt, response))
    return response
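# Example (illustrative): respond("What is FAISS?") encodes the query, matches the
# first FAQ entry at distance ~0, and returns its stored answer without ever
# calling the language model.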
# Clear the conversation history
def clear_history():
    global history
    history = []
    return "History cleared!"
# Build the Gradio interface
with gr.Blocks() as demo:
    gr.Markdown("# DeepScaleR-1.5B-Preview (OpenVINO int8) Chatbot with FAISS FAQ (copy this Space to your own account before using it)")
    with gr.Tabs():
        with gr.TabItem("Chat"):
            chat_interface = gr.Interface(
                fn=respond,
                inputs=gr.Textbox(label="Prompt", placeholder="Enter your message..."),
                outputs=gr.Textbox(label="Response", interactive=False),
                api_name="hchat",
                title="DeepScaleR-1.5B-Preview (OpenVINO int8) Chatbot",
                description="This chatbot first searches an FAQ database using FAISS, then uses an OpenVINO 8-bit model to generate a response if no FAQ match is found."
            )
            with gr.Row():
                clear_button = gr.Button("🧹 Clear History")
                clear_button.click(fn=clear_history, inputs=[], outputs=[])
if __name__ == "__main__":
print("Launching Gradio app...")
demo.launch(server_name="0.0.0.0", server_port=7860, share=True)