import gradio as gr
import onnxruntime_genai as og
import time
import os
from huggingface_hub import snapshot_download
import argparse
import logging
import numpy as np

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
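# --- Configuration ---
# The constants below choose which ONNX export of Phi-4 Mini is downloaded and served.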
MODEL_REPO = "microsoft/Phi-4-mini-instruct-onnx"

EXECUTION_PROVIDER = "cpu"
MODEL_VARIANT_GLOB = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/*"
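
# The model repo also publishes GPU-targeted variants. Switching to one is untested here
# and would need the matching GPU build of onnxruntime-genai plus the right folder glob,
# for example (hypothetical values -- check the model card for the exact folder name):
#   EXECUTION_PROVIDER = "cuda"
#   MODEL_VARIANT_GLOB = "gpu/<gpu-variant-folder>/*"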

LOCAL_MODEL_DIR = "./phi4-mini-onnx-model"
HF_LOGO_URL = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"
HF_MODEL_URL = f"https://huggingface.co/{MODEL_REPO}"
ORT_GENAI_URL = "https://github.com/microsoft/onnxruntime-genai"
PHI_LOGO_URL = "https://microsoft.github.io/phi/assets/img/logo-final.png"
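
# Module-level state shared between the Gradio callbacks below.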
model = None
tokenizer = None
model_variant_name = os.path.basename(os.path.dirname(MODEL_VARIANT_GLOB))
model_status = "Initializing..."


def initialize_model():
    """Downloads and loads the ONNX model and tokenizer."""
    global model, tokenizer, model_status
    logging.info("--- Initializing ONNX Runtime GenAI ---")
    model_status = "Downloading model..."
    logging.info(model_status)
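
    # Reuse a previously downloaded copy if present; otherwise fetch only the files
    # matching MODEL_VARIANT_GLOB (allow_patterns limits the snapshot to that sub-folder).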
    model_variant_dir = os.path.join(LOCAL_MODEL_DIR, os.path.dirname(MODEL_VARIANT_GLOB))
    if os.path.exists(model_variant_dir) and os.listdir(model_variant_dir):
        logging.info(f"Model variant found in {model_variant_dir}. Skipping download.")
        model_path = model_variant_dir
    else:
        logging.info(f"Downloading model variant '{MODEL_VARIANT_GLOB}' from {MODEL_REPO}...")
        try:
            snapshot_download(
                MODEL_REPO,
                allow_patterns=[MODEL_VARIANT_GLOB],
                local_dir=LOCAL_MODEL_DIR,
                local_dir_use_symlinks=False
            )
            model_path = model_variant_dir
            logging.info(f"Model downloaded to: {model_path}")
        except Exception as e:
            logging.error(f"Error downloading model: {e}", exc_info=True)
            model_status = f"Error downloading model: {e}"
            raise RuntimeError(f"Failed to download model: {e}")
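
    # Load the ONNX model and tokenizer with onnxruntime-genai. The execution provider
    # actually used depends on which onnxruntime-genai package is installed (see log below).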
    model_status = f"Loading model ({EXECUTION_PROVIDER.upper()})..."
    logging.info(model_status)
    try:
        logging.info(f"Using provider based on installed package (expecting: {EXECUTION_PROVIDER})")
        model = og.Model(model_path)
        tokenizer = og.Tokenizer(model)
        model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"
        logging.info("Model and Tokenizer loaded successfully.")
    except AttributeError as ae:
        logging.error(f"AttributeError during model/tokenizer init: {ae}", exc_info=True)
        logging.error("This might indicate an installation issue or version incompatibility with onnxruntime_genai.")
        model_status = f"Init Error: {ae}"
        raise RuntimeError(f"Failed to initialize model/tokenizer: {ae}")
    except Exception as e:
        logging.error(f"Error loading model or tokenizer: {e}", exc_info=True)
        model_status = f"Error loading model: {e}"
        raise RuntimeError(f"Failed to load model: {e}")


def generate_response_stream(prompt, history, max_length, temperature, top_p, top_k):
    """Generates a response using the Phi-4 ONNX model, yielding text chunks."""
    global model_status
    if not model or not tokenizer:
        model_status = "Error: Model not initialized!"
        yield "Error: Model not initialized. Please check logs."
        return
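
    # Rebuild the conversation in the Phi-4 chat template:
    #   <|user|>\n{message}<|end|>\n<|assistant|>\n{reply}<|end|>\n
    # ending with an open <|assistant|> turn for the new response.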
    full_prompt = ""
    for user_msg, assistant_msg in history:
        full_prompt += f"<|user|>\n{user_msg}<|end|>\n"
        if assistant_msg:
            full_prompt += f"<|assistant|>\n{assistant_msg}<|end|>\n"

    full_prompt += f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"

    logging.info(f"Generating response (MaxL: {max_length}, Temp: {temperature}, TopP: {top_p}, TopK: {top_k})")
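
    # Tokenize the prompt, configure sampling, and stream tokens back one at a time.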
    try:
        input_tokens_list = tokenizer.encode(full_prompt)
        input_tokens = np.array(input_tokens_list, dtype=np.int32)
        input_tokens = input_tokens.reshape((1, -1))
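
        # Sampling controls forwarded to onnxruntime-genai's search options;
        # do_sample=True enables temperature/top-p/top-k sampling instead of greedy decoding.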
        search_options = {
            "max_length": max_length,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "do_sample": True,
        }

        params = og.GeneratorParams(model)
        params.set_search_options(**search_options)
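
        # Hand the tokenized prompt to the generator as a named 'input_ids' tensor.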
        inputs = {"input_ids": input_tokens}
        logging.info(f"Setting inputs with keys: {inputs.keys()} and shape for 'input_ids': {inputs['input_ids'].shape}")
        params.set_inputs(inputs)

        start_time = time.time()
        generator = og.Generator(model, params)
        model_status = "Generating..."
        logging.info("Streaming response...")

        first_token_time = None
        token_count = 0
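
        # Decode loop: advance the generator one token per iteration and yield each
        # decoded chunk immediately so the Gradio chatbot can stream it.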
        while not generator.is_done():
            try:
                generator.compute_logits()
                generator.generate_next_token()
                if first_token_time is None:
                    first_token_time = time.time()

                next_token = generator.get_next_tokens()[0]
                decoded_chunk = tokenizer.decode([next_token])
                token_count += 1

                if decoded_chunk == "<|end|>":
                    logging.info("Assistant explicitly generated <|end|> token string.")
                    break

                yield decoded_chunk
            except Exception as loop_error:
                logging.error(f"Error inside generation loop: {loop_error}", exc_info=True)
                yield f"\n\nError during token generation: {loop_error}"
                break
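
        # Basic throughput stats: time-to-first-token (ms) and tokens per second.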
        end_time = time.time()
        ttft = (first_token_time - start_time) * 1000 if first_token_time else -1
        total_time = end_time - start_time
        tps = (token_count / total_time) if total_time > 0 else 0

        logging.info(f"Generation complete. Tokens: {token_count}, Total Time: {total_time:.2f}s, TTFT: {ttft:.2f}ms, TPS: {tps:.2f}")
        model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"
    except TypeError as te:
        logging.error(f"TypeError during generation setup: {te}", exc_info=True)
        logging.error("Check if the input format {'input_ids': token_array} is correct.")
        model_status = f"Generation Setup TypeError: {te}"
        yield f"\n\nSorry, a TypeError occurred setting up generation: {te}"
    except AttributeError as ae:
        logging.error(f"AttributeError during generation setup: {ae}", exc_info=True)
        model_status = f"Generation Setup Error: {ae}"
        yield f"\n\nSorry, an error occurred setting up generation: {ae}"
    except Exception as e:
        logging.error(f"Error during generation: {e}", exc_info=True)
        model_status = f"Error during generation: {e}"
        yield f"\n\nSorry, an error occurred during generation: {e}"


def add_user_message(user_message, history):
    """Adds the user's message to the chat history for display."""
    if not user_message:
        return "", history
    history = history + [[user_message, None]]
    return "", history


def generate_bot_response(history, max_length, temperature, top_p, top_k):
    """Generates the bot's response based on the history and streams it."""
    if not history or history[-1][1] is not None:
        # Nothing to generate: no history yet, or the last turn already has a reply.
        # (This is a generator, so yield the unchanged history rather than returning it.)
        yield history
        return

    user_prompt = history[-1][0]
    model_history = history[:-1]

    response_stream = generate_response_stream(
        user_prompt, model_history, max_length, temperature, top_p, top_k
    )

    # Stream the reply into the last chat row, re-yielding the history after each chunk.
    history[-1][1] = ""
    for chunk in response_stream:
        history[-1][1] += chunk
        yield history


def clear_chat():
    """Clears the chat history and input."""
    global model_status
    if model and tokenizer and not model_status.startswith("Error") and not model_status.startswith("FATAL"):
        model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"
    return None, [], model_status
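

# Initialize at import time. On failure the app still launches and surfaces the error
# through model_status in the UI instead of crashing.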
try:
    initialize_model()
except Exception as e:
    print(f"FATAL: Model initialization failed: {e}")

logging.info("Creating Gradio Interface...")

theme = gr.themes.Soft(
    primary_hue="blue",
    secondary_hue="sky",
    neutral_hue="slate",
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
)
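
# --- Gradio UI ---
# Layout: header row (title + status), a chat column with input controls, and a
# sidebar of generation settings wired to the streaming callbacks below.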
with gr.Blocks(theme=theme, title="Phi-4 Mini ONNX Chat") as demo:

    with gr.Row(equal_height=False):
        with gr.Column(scale=3):
            gr.Markdown(f"""
            # Phi-4 Mini Instruct ONNX Chat 🤖
            Interact with the quantized `{model_variant_name}` version of [`{MODEL_REPO}`]({HF_MODEL_URL})
            running efficiently via [`onnxruntime-genai`]({ORT_GENAI_URL}) ({EXECUTION_PROVIDER.upper()}).
            """)
        with gr.Column(scale=1, min_width=150):
            gr.Image(HF_LOGO_URL, elem_id="hf-logo", show_label=False, show_download_button=False, container=False, height=50)
            model_status_text = gr.Textbox(value=model_status, label="Model Status", interactive=False, max_lines=2)

    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(
                label="Conversation",
                height=600,
                layout="bubble",
                bubble_full_width=False,
                avatar_images=(None, PHI_LOGO_URL)
            )
            with gr.Row():
                prompt_input = gr.Textbox(
                    label="Your Message",
                    placeholder="<|user|>\nType your message here...\n<|end|>",
                    lines=4,
                    scale=9
                )
                with gr.Column(scale=1, min_width=120):
                    submit_button = gr.Button("Send", variant="primary", size="lg")
                    clear_button = gr.Button("🗑️ Clear Chat", variant="secondary")
        with gr.Column(scale=1, min_width=250):
            gr.Markdown("### ⚙️ Generation Settings")
            with gr.Group():
                max_length = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max Length", info="Max tokens in response.")
                temperature = gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature", info="0.0 = deterministic\n>1.0 = more random")
                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-P", info="Nucleus sampling probability.")
                top_k = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K", info="Limit to K most likely tokens (0=disable).")
            gr.Markdown("---")
            gr.Markdown("ℹ️ **Note:** Uses Phi-4 instruction format: \n`<|user|>\nPROMPT<|end|>\n<|assistant|>`")
            gr.Markdown(f"Running on **{EXECUTION_PROVIDER.upper()}**.")
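
    # Event wiring: submitting the textbox or clicking Send first appends the user
    # turn (fast, unqueued), then streams the bot reply into the last chat row.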
    bot_response_inputs = [chatbot, max_length, temperature, top_p, top_k]

    submit_event = prompt_input.submit(
        fn=add_user_message,
        inputs=[prompt_input, chatbot],
        outputs=[prompt_input, chatbot],
        queue=False,
    ).then(
        fn=generate_bot_response,
        inputs=bot_response_inputs,
        outputs=[chatbot],
        api_name="chat"
    )

    submit_button.click(
        fn=add_user_message,
        inputs=[prompt_input, chatbot],
        outputs=[prompt_input, chatbot],
        queue=False,
    ).then(
        fn=generate_bot_response,
        inputs=bot_response_inputs,
        outputs=[chatbot],
        api_name=False
    )

    clear_button.click(
        fn=clear_chat,
        inputs=None,
        outputs=[prompt_input, chatbot, model_status_text],
        queue=False
    )
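

# Cap the request queue at 20 pending jobs; the queue also powers the streaming
# (generator) outputs used above, so keep it enabled.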
logging.info("Launching Gradio App...")
demo.queue(max_size=20)
demo.launch(show_error=True, max_threads=40)