# Qwen-test / app.py
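"""Gradio chat demo for the ystemsrx/Qwen2.5-Sex checkpoint.

Loads the model and tokenizer with Transformers, streams replies through a
TextIteratorStreamer running generate() on a background thread, and exposes
top-p, top-k, temperature, and max-new-tokens controls in a Blocks UI.
"""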
import random
from threading import Thread

import gradio as gr
import torch
import transformers
from packaging import version
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.trainer_utils import set_seed
# Default generation parameters
DEFAULT_TOP_P = 0.9
DEFAULT_TOP_K = 80
DEFAULT_TEMPERATURE = 0.3
DEFAULT_MAX_NEW_TOKENS = 512
DEFAULT_SYSTEM_MESSAGE = ""
DEFAULT_CKPT_PATH = "ystemsrx/Qwen2.5-Sex"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

def _load_model_tokenizer(checkpoint_path):
    """Load the model and tokenizer, placing the model on GPU when available."""
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, resume_download=True)
    device_map = "auto" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_path,
        device_map=device_map,  # "auto" spreads the model across available GPUs (needs accelerate)
        torch_dtype=torch_dtype,
        resume_download=True,
    ).eval()
    model.generation_config.max_new_tokens = DEFAULT_MAX_NEW_TOKENS
    return model, tokenizer

def _chat_stream(model, tokenizer, query, history, system_message, top_p, top_k, temperature, max_new_tokens):
    """Yield the assistant reply incrementally, accumulated so far at each step."""
    # Rebuild the full conversation: system message, prior turns, then the new query.
    conversation = [{'role': 'system', 'content': system_message}]
    for query_h, response_h in history:
        conversation.append({'role': 'user', 'content': query_h})
        conversation.append({'role': 'assistant', 'content': response_h})
    conversation.append({'role': 'user', 'content': query})

    # apply_chat_template was introduced in transformers 4.34; fall back to a
    # plain-text prompt on older versions.
    if version.parse(transformers.__version__) >= version.parse("4.34"):
        text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    else:
        text = "\n".join(f"{msg['role'].capitalize()}: {msg['content']}" for msg in conversation) + "\nAssistant:"

    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, timeout=30.0, skip_special_tokens=True)
    generation_kwargs = dict(
        **inputs,  # input_ids and attention_mask
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )

    # Run generate() on a worker thread while this thread consumes the streamer;
    # join only after the streamer is exhausted so tokens flow as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    assistant_reply = ""
    for new_text in streamer:
        assistant_reply += new_text
        yield assistant_reply
    thread.join()
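
# A minimal usage sketch (sampling values are illustrative), streaming a
# one-turn reply once `model` and `tokenizer` are loaded:
#   for partial in _chat_stream(model, tokenizer, "Hello", [], "", 0.9, 80, 0.3, 128):
#       print(partial)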

def initialize_model(checkpoint_path=DEFAULT_CKPT_PATH):
    # Seed the RNGs randomly so each process samples differently.
    set_seed(random.randint(0, 2**32 - 1))
    return _load_model_tokenizer(checkpoint_path)


# Loaded once at import time and shared across all Gradio sessions.
model, tokenizer = initialize_model()

def chat_interface(user_input, history, system_message, top_p, top_k, temperature, max_new_tokens):
    if not user_input.strip():
        yield history, history, system_message, ""
        return

    # Show the user turn immediately with an empty assistant slot.
    history.append((user_input, ""))
    yield history, history, system_message, ""

    # history[:-1] excludes the in-progress turn; _chat_stream yields the reply
    # accumulated so far, so update the last turn in place.
    generator = _chat_stream(model, tokenizer, user_input, history[:-1], system_message, top_p, top_k, temperature, max_new_tokens)
    for assistant_reply in generator:
        history[-1] = (user_input, assistant_reply)
        yield history, history, system_message, ""

def clear_history():
    # Reset chatbot, state, system message, and the input box.
    return [], [], DEFAULT_SYSTEM_MESSAGE, ""

# Gradio UI
demo = gr.Blocks()
with demo:
    gr.Markdown("# Qwen2.5 Chatbot")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot()
            user_input = gr.Textbox(show_label=False, placeholder="Type your question...")
            send_btn = gr.Button("Send")
        with gr.Column(scale=1):
            clear_btn = gr.Button("Clear History")
            system_message = gr.Textbox(label="System Message", value=DEFAULT_SYSTEM_MESSAGE)
            top_p_slider = gr.Slider(0.1, 1.0, value=DEFAULT_TOP_P, label="Top-p")
            top_k_slider = gr.Slider(0, 100, value=DEFAULT_TOP_K, label="Top-k")
            temperature_slider = gr.Slider(0.1, 1.5, value=DEFAULT_TEMPERATURE, label="Temperature")
            max_new_tokens_slider = gr.Slider(50, 2048, value=DEFAULT_MAX_NEW_TOKENS, label="Max New Tokens")

    state = gr.State([])
    user_input.submit(
        chat_interface,
        [user_input, state, system_message, top_p_slider, top_k_slider, temperature_slider, max_new_tokens_slider],
        [chatbot, state, system_message, user_input],
        queue=True,
    )
    send_btn.click(
        chat_interface,
        [user_input, state, system_message, top_p_slider, top_k_slider, temperature_slider, max_new_tokens_slider],
        [chatbot, state, system_message, user_input],
        queue=True,
    )
    clear_btn.click(clear_history, None, [chatbot, state, system_message, user_input], queue=True)

demo.launch()
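# Run with `python app.py`; Gradio serves the UI at http://127.0.0.1:7860 by default.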