import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
from threading import Thread
# Model identifiers on the Hugging Face Hub
phi4_model_path = "microsoft/phi-4"
phi4_mini_model_path = "microsoft/Phi-4-mini-instruct"

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Both checkpoints are loaded once at startup so the dropdown can switch models
# without a reload; torch_dtype="auto" keeps each checkpoint's saved precision.
phi4_model = AutoModelForCausalLM.from_pretrained(phi4_model_path, torch_dtype="auto").to(device)
phi4_tokenizer = AutoTokenizer.from_pretrained(phi4_model_path)
phi4_mini_model = AutoModelForCausalLM.from_pretrained(phi4_mini_model_path, torch_dtype="auto").to(device)
phi4_mini_tokenizer = AutoTokenizer.from_pretrained(phi4_mini_model_path)
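
# @spaces.GPU marks the handler below to run on a ZeroGPU slice when hosted on
# Hugging Face Spaces; duration=60 caps each call at about 60 seconds of GPU time.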
@spaces.GPU(duration=60)
def generate_response(user_message, model_name, max_tokens, temperature, top_k, top_p, repetition_penalty, history_state):
    # Ignore empty submissions
    if not user_message.strip():
        return history_state, history_state
    # Select the model and its chat-format special tokens
    if model_name == "Phi-4":
        model = phi4_model
        tokenizer = phi4_tokenizer
        start_tag = "<|im_start|>"
        sep_tag = "<|im_sep|>"
        end_tag = "<|im_end|>"
    elif model_name == "Phi-4-mini-instruct":
        model = phi4_mini_model
        tokenizer = phi4_mini_tokenizer
        start_tag = ""
        sep_tag = ""
        end_tag = "<|end|>"
    else:
        raise ValueError(f"Unknown model selection: {model_name}")
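
    # Prompt shapes built below:
    #   Phi-4:               <|im_start|>role<|im_sep|>content<|im_end|> ... <|im_start|>assistant<|im_sep|>
    #   Phi-4-mini-instruct: <|role|>content<|end|> ... <|assistant|>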
    # System prompt recommended by Microsoft
    system_message = "You are a friendly and knowledgeable assistant, here to help with any questions or tasks."

    if model_name == "Phi-4":
        prompt = f"{start_tag}system{sep_tag}{system_message}{end_tag}"
        for message in history_state:
            if message["role"] == "user":
                prompt += f"{start_tag}user{sep_tag}{message['content']}{end_tag}"
            elif message["role"] == "assistant" and message["content"]:
                prompt += f"{start_tag}assistant{sep_tag}{message['content']}{end_tag}"
        prompt += f"{start_tag}user{sep_tag}{user_message}{end_tag}{start_tag}assistant{sep_tag}"
    else:
        prompt = f"<|system|>{system_message}{end_tag}"
        for message in history_state:
            if message["role"] == "user":
                prompt += f"<|user|>{message['content']}{end_tag}"
            elif message["role"] == "assistant" and message["content"]:
                prompt += f"<|assistant|>{message['content']}{end_tag}"
        prompt += f"<|user|>{user_message}{end_tag}<|assistant|>"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # Neutral settings (temperature 1.0, top-k >= 100, top-p 1.0) fall back to
    # greedy decoding instead of sampling
    do_sample = not (temperature == 1.0 and top_k >= 100 and top_p == 1.0)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

    # Sampling settings
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "max_new_tokens": int(max_tokens),
        "do_sample": do_sample,
        "temperature": temperature,
        "top_k": int(top_k),
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "streamer": streamer,
    }

    # Generate on a worker thread so this handler can consume the stream
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
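
    # TextIteratorStreamer hands decoded text chunks from the worker thread to the
    # loop below (skip_prompt=True drops the echoed input), so the UI can update
    # while generation is still running.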
    # Stream the response, scrubbing any special tokens that leak into the text
    special_tokens = ("<|im_start|>", "<|im_sep|>", "<|im_end|>", "<|end|>", "<|system|>", "<|user|>", "<|assistant|>")
    assistant_response = ""
    new_history = history_state + [
        {"role": "user", "content": user_message},
        {"role": "assistant", "content": ""},
    ]
    for new_token in streamer:
        for tag in special_tokens:
            new_token = new_token.replace(tag, "")
        assistant_response += new_token
        new_history[-1]["content"] = assistant_response.strip()
        yield new_history, new_history
    yield new_history, new_history
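
# Because generate_response is a generator, Gradio re-renders its outputs on every
# yield: the pair (new_history, new_history) updates the Chatbot display and the
# session state together, which produces the incremental typing effect.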
example_messages = {
    "Learn about physics": "Explain Newton’s laws of motion.",
    "Discover space facts": "What are some interesting facts about black holes?",
    "Write a factorial function": "Write a Python function to calculate the factorial of a number."
}
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Phi-4 Models Chatbot
        Welcome to the Phi-4 Chatbot! You can chat with Microsoft's Phi-4 or Phi-4-mini-instruct models. Adjust the settings on the left to customize the model's responses.
        """
    )
    history_state = gr.State([])
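
    # Per-session chat history held as {"role": ..., "content": ...} dicts,
    # matching the Chatbot's "messages" format below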
    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Settings")
            model_dropdown = gr.Dropdown(
                choices=["Phi-4", "Phi-4-mini-instruct"],
                label="Select Model",
                value="Phi-4"
            )
            max_tokens_slider = gr.Slider(
                minimum=64,
                maximum=4096,
                step=50,
                value=512,
                label="Max Tokens"
            )
            with gr.Accordion("Advanced Settings", open=False):
                temperature_slider = gr.Slider(
                    minimum=0.1,
                    maximum=2.0,
                    value=1.0,
                    label="Temperature"
                )
                top_k_slider = gr.Slider(
                    minimum=1,
                    maximum=100,
                    step=1,
                    value=50,
                    label="Top-k"
                )
                top_p_slider = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    label="Top-p"
                )
                repetition_penalty_slider = gr.Slider(
                    minimum=1.0,
                    maximum=2.0,
                    value=1.0,
                    label="Repetition Penalty"
                )
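
            # Setting Temperature to 1.0, Top-k to 100, and Top-p to 1.0 switches
            # generate_response to greedy decoding (see the do_sample check above)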
        with gr.Column(scale=4):
            chatbot = gr.Chatbot(label="Chat", type="messages")
            with gr.Row():
                user_input = gr.Textbox(
                    label="Your message",
                    placeholder="Type your message here...",
                    scale=3
                )
                submit_button = gr.Button("Send", variant="primary", scale=1)
                clear_button = gr.Button("Clear", scale=1)
            gr.Markdown("**Try these examples:**")
            with gr.Row():
                example1_button = gr.Button("Learn about physics")
                example2_button = gr.Button("Discover space facts")
                example3_button = gr.Button("Write a factorial function")
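
    # Event wiring: .click() streams generate_response into the Chatbot, then the
    # chained .then() clears the textbox once the stream has finished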
    submit_button.click(
        fn=generate_response,
        inputs=[user_input, model_dropdown, max_tokens_slider, temperature_slider, top_k_slider, top_p_slider, repetition_penalty_slider, history_state],
        outputs=[chatbot, history_state]
    ).then(
        fn=lambda: gr.update(value=""),
        inputs=None,
        outputs=user_input
    )

    clear_button.click(
        fn=lambda: ([], []),
        inputs=None,
        outputs=[chatbot, history_state]
    )
    example1_button.click(
        fn=lambda: gr.update(value=example_messages["Learn about physics"]),
        inputs=None,
        outputs=user_input
    )
    example2_button.click(
        fn=lambda: gr.update(value=example_messages["Discover space facts"]),
        inputs=None,
        outputs=user_input
    )
    example3_button.click(
        fn=lambda: gr.update(value=example_messages["Write a factorial function"]),
        inputs=None,
        outputs=user_input
    )

demo.launch(ssr_mode=False)