import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.trainer_utils import set_seed
from packaging import version
import transformers
from threading import Thread
import random
import os
import gradio as gr

DEFAULT_TOP_P = 0.9
DEFAULT_TOP_K = 80
DEFAULT_TEMPERATURE = 0.3
DEFAULT_MAX_NEW_TOKENS = 512
DEFAULT_SYSTEM_MESSAGE = ""

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
cpu_only = not torch.cuda.is_available()
DEFAULT_CKPT_PATH = "ystemsrx/Qwen2.5-Sex"

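# Load the tokenizer and model from the Hugging Face Hub (or a local checkpoint directory).
# When a GPU is present the weights are loaded in fp16 and placed with device_map="auto";
# otherwise everything runs in fp32 on the CPU.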
def _load_model_tokenizer(checkpoint_path):
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path, resume_download=True)
    device_map = "auto" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

    # Pass device_map through so the weights actually land on the GPU when one is available
    # (otherwise the model stays on the CPU while the inputs are moved to DEVICE).
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_path, torch_dtype=torch_dtype, device_map=device_map, resume_download=True
    ).eval()

    model.generation_config.max_new_tokens = DEFAULT_MAX_NEW_TOKENS
    return model, tokenizer

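# Streamed generation: the conversation (system message, prior turns, new query) is rendered
# with the chat template, model.generate runs on a background thread, and TextIteratorStreamer
# yields decoded fragments that are accumulated into the growing assistant reply.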
def _chat_stream(model, tokenizer, query, history, system_message, top_p, top_k, temperature, max_new_tokens):
    conversation = [{'role': 'system', 'content': system_message}]
    for query_h, response_h in history:
        conversation.append({'role': 'user', 'content': query_h})
        conversation.append({'role': 'assistant', 'content': response_h})
    conversation.append({'role': 'user', 'content': query})

    # Chat templates are only available from transformers 4.34 onwards; fall back to a plain prompt otherwise.
    if version.parse(transformers.__version__) >= version.parse("4.34"):
        text = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    else:
        text = "\n".join([f"{msg['role'].capitalize()}: {msg['content']}" for msg in conversation]) + "\nAssistant:"

    inputs = tokenizer(text, return_tensors="pt").to(DEVICE)
    streamer = TextIteratorStreamer(tokenizer=tokenizer, skip_prompt=True, timeout=30.0, skip_special_tokens=True)

    generation_kwargs = dict(
        input_ids=inputs["input_ids"], attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        top_p=top_p, top_k=top_k, temperature=temperature, do_sample=True,
        pad_token_id=tokenizer.eos_token_id, streamer=streamer
    )

    # Generate on a background thread and consume the streamer as tokens arrive; joining the
    # thread before reading the streamer would block until generation finishes and defeat streaming.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    assistant_reply = ""
    for new_text in streamer:
        assistant_reply += new_text
        # Yield the full reply accumulated so far; the caller overwrites its last turn with it.
        yield assistant_reply

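# Re-seed the RNGs with a fresh random value on startup so sampled generations vary between runs.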
def initialize_model(checkpoint_path=DEFAULT_CKPT_PATH):
    set_seed(random.randint(0, 2**32 - 1))
    return _load_model_tokenizer(checkpoint_path)


model, tokenizer = initialize_model()

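# Gradio callback: appends the new user turn, then streams the assistant reply into the last
# history entry, yielding intermediate states so the chat window updates live.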
def chat_interface(user_input, history, system_message, top_p, top_k, temperature, max_new_tokens):
    if not user_input.strip():
        yield history, history, system_message, ""
        return

    # Show the user's message immediately with an empty assistant slot, then stream into it.
    history.append((user_input, ""))
    yield history, history, system_message, ""

    generator = _chat_stream(model, tokenizer, user_input, history[:-1], system_message, top_p, top_k, temperature, max_new_tokens)
    for assistant_reply in generator:
        # _chat_stream already yields the accumulated reply, so assign it rather than concatenating
        # (concatenating would duplicate earlier fragments on every update).
        history[-1] = (user_input, assistant_reply)
        yield history, history, system_message, ""

def clear_history():
    # Reset the chatbot, the stored history, the system message, and the input box.
    # Plain return values are used instead of gr.Textbox.update, which newer Gradio releases no longer provide.
    return [], [], DEFAULT_SYSTEM_MESSAGE, ""

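# UI layout: chat window, input box and send button on the left; system message, sampling
# controls and the clear-history button on the right.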
demo = gr.Blocks()
with demo:
    gr.Markdown("# Qwen2.5 Chatbot")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot()
            user_input = gr.Textbox(show_label=False, placeholder="Type your question...")
            send_btn = gr.Button("Send")

        with gr.Column(scale=1):
            clear_btn = gr.Button("Clear History")
            system_message = gr.Textbox(label="System Message", value=DEFAULT_SYSTEM_MESSAGE)
            top_p_slider = gr.Slider(0.1, 1.0, value=DEFAULT_TOP_P, label="Top-p")
            top_k_slider = gr.Slider(0, 100, value=DEFAULT_TOP_K, label="Top-k")
            temperature_slider = gr.Slider(0.1, 1.5, value=DEFAULT_TEMPERATURE, label="Temperature")
            max_new_tokens_slider = gr.Slider(50, 2048, value=DEFAULT_MAX_NEW_TOKENS, label="Max New Tokens")

    # The State component holds the list of (user, assistant) turns shared across callbacks.
    state = gr.State([])

    user_input.submit(
        chat_interface,
        [user_input, state, system_message, top_p_slider, top_k_slider, temperature_slider, max_new_tokens_slider],
        [chatbot, state, system_message, user_input],
        queue=True,
    )
    send_btn.click(
        chat_interface,
        [user_input, state, system_message, top_p_slider, top_k_slider, temperature_slider, max_new_tokens_slider],
        [chatbot, state, system_message, user_input],
        queue=True,
    )
    clear_btn.click(clear_history, None, [chatbot, state, system_message, user_input], queue=True)

demo.launch()
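# demo.launch() serves the app locally (Gradio defaults to http://127.0.0.1:7860).
# If remote access is needed, standard Gradio launch options such as
# demo.launch(server_name="0.0.0.0") or demo.launch(share=True) can be used instead.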