import gradio as gr
import onnxruntime_genai as og
import time
import os
from huggingface_hub import snapshot_download
import argparse
import logging
import numpy as np  # token ids are packed into a (1, seq_len) int32 array before generation
# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- Configuration ---
MODEL_REPO = "microsoft/Phi-4-mini-instruct-onnx"
# --- Defaulting to CPU INT4 for Hugging Face Spaces ---
EXECUTION_PROVIDER = "cpu" # Corresponds to installing 'onnxruntime-genai'
MODEL_VARIANT_GLOB = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/*"
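# Only files matching this glob are fetched (see allow_patterns below), so the rest of the ONNX repo is never downloaded.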
# --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---
# --- (Optional) Alternative GPU Configuration ---
# EXECUTION_PROVIDER = "cuda" # Corresponds to installing 'onnxruntime-genai-cuda'
# MODEL_VARIANT_GLOB = "gpu/gpu-int4-rtn-block-32/*"
# --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---
LOCAL_MODEL_DIR = "./phi4-mini-onnx-model" # Directory within the Space
HF_LOGO_URL = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"
HF_MODEL_URL = f"https://huggingface.co/{MODEL_REPO}"
ORT_GENAI_URL = "https://github.com/microsoft/onnxruntime-genai"
PHI_LOGO_URL = "https://microsoft.github.io/phi/assets/img/logo-final.png" # Phi logo for bot avatar
# Global variables for model and tokenizer
model = None
tokenizer = None
model_variant_name = os.path.basename(os.path.dirname(MODEL_VARIANT_GLOB)) # For display
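# With the default CPU glob above this resolves to "cpu-int4-rtn-block-32-acc-level-4".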
model_status = "Initializing..."
# --- Model Download and Load ---
def initialize_model():
"""Downloads and loads the ONNX model and tokenizer."""
global model, tokenizer, model_status
logging.info("--- Initializing ONNX Runtime GenAI ---")
model_status = "Downloading model..."
logging.info(model_status)
# --- Download ---
model_variant_dir = os.path.join(LOCAL_MODEL_DIR, os.path.dirname(MODEL_VARIANT_GLOB))
if os.path.exists(model_variant_dir) and os.listdir(model_variant_dir):
logging.info(f"Model variant found in {model_variant_dir}. Skipping download.")
model_path = model_variant_dir
else:
logging.info(f"Downloading model variant '{MODEL_VARIANT_GLOB}' from {MODEL_REPO}...")
try:
snapshot_download(
MODEL_REPO,
allow_patterns=[MODEL_VARIANT_GLOB],
local_dir=LOCAL_MODEL_DIR,
local_dir_use_symlinks=False
)
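            # snapshot_download mirrors the repo's folder layout under LOCAL_MODEL_DIR, which is why
            # model_path points at the variant subdirectory. Note: local_dir_use_symlinks is deprecated
            # in recent huggingface_hub releases (downloads to local_dir no longer use symlinks).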
model_path = model_variant_dir
logging.info(f"Model downloaded to: {model_path}")
except Exception as e:
logging.error(f"Error downloading model: {e}", exc_info=True)
model_status = f"Error downloading model: {e}"
raise RuntimeError(f"Failed to download model: {e}")
# --- Load ---
model_status = f"Loading model ({EXECUTION_PROVIDER.upper()})..."
logging.info(model_status)
try:
        # og.Model picks the execution provider from whichever onnxruntime-genai package is installed (CPU vs. CUDA build).
logging.info(f"Using provider based on installed package (expecting: {EXECUTION_PROVIDER})")
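        # og.Model loads from the variant directory, which contains genai_config.json and the ONNX weights.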
model = og.Model(model_path) # Simplified model loading
tokenizer = og.Tokenizer(model)
model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"
logging.info("Model and Tokenizer loaded successfully.")
except AttributeError as ae:
logging.error(f"AttributeError during model/tokenizer init: {ae}", exc_info=True)
logging.error("This might indicate an installation issue or version incompatibility with onnxruntime_genai.")
model_status = f"Init Error: {ae}"
raise RuntimeError(f"Failed to initialize model/tokenizer: {ae}")
except Exception as e:
logging.error(f"Error loading model or tokenizer: {e}", exc_info=True)
model_status = f"Error loading model: {e}"
raise RuntimeError(f"Failed to load model: {e}")
# --- Generation Function (Core Logic) ---
def generate_response_stream(prompt, history, max_length, temperature, top_p, top_k):
"""Generates a response using the Phi-4 ONNX model, yielding text chunks."""
global model_status
if not model or not tokenizer:
model_status = "Error: Model not initialized!"
yield "Error: Model not initialized. Please check logs."
return
# --- Prepare the prompt using the Phi-4 instruct format ---
full_prompt = ""
# History format is [[user1, bot1], [user2, bot2], ...]
for user_msg, assistant_msg in history: # history here is *before* the current prompt
full_prompt += f"<|user|>\n{user_msg}<|end|>\n"
if assistant_msg: # Append assistant message only if it exists
full_prompt += f"<|assistant|>\n{assistant_msg}<|end|>\n"
# Add the current user prompt and the trigger for the assistant's response
full_prompt += f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
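    # Example: history [["Hi", "Hello!"]] plus the new prompt "How are you?" produces
    # "<|user|>\nHi<|end|>\n<|assistant|>\nHello!<|end|>\n<|user|>\nHow are you?<|end|>\n<|assistant|>\n"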
logging.info(f"Generating response (MaxL: {max_length}, Temp: {temperature}, TopP: {top_p}, TopK: {top_k})")
try:
input_tokens_list = tokenizer.encode(full_prompt) # Encode returns a list/array
# Ensure input_tokens is a numpy array of the correct type (int32 is common)
input_tokens = np.array(input_tokens_list, dtype=np.int32)
# Reshape to (batch_size, sequence_length), which is (1, N) for single prompt
input_tokens = input_tokens.reshape((1, -1))
search_options = {
"max_length": max_length,
"temperature": temperature,
"top_p": top_p,
"top_k": top_k,
"do_sample": True,
}
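        # Note: in onnxruntime-genai, max_length bounds the total sequence (prompt tokens + generated tokens),
        # not just the length of the reply.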
params = og.GeneratorParams(model)
params.set_search_options(**search_options)
# FIX: Create a dictionary mapping input names to tensors (numpy arrays)
# and pass this dictionary to set_inputs.
# Assuming the standard input name "input_ids".
inputs = {"input_ids": input_tokens}
        logging.info(f"Setting inputs with keys: {list(inputs.keys())} and shape for 'input_ids': {inputs['input_ids'].shape}")
params.set_inputs(inputs)
start_time = time.time()
# Create generator AFTER setting parameters including inputs
generator = og.Generator(model, params)
model_status = "Generating..." # Update status indicator
logging.info("Streaming response...")
first_token_time = None
token_count = 0
# Rely primarily on generator.is_done()
while not generator.is_done():
try:
generator.compute_logits()
generator.generate_next_token()
if first_token_time is None:
first_token_time = time.time() # Record time to first token
next_token = generator.get_next_tokens()[0]
decoded_chunk = tokenizer.decode([next_token])
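                # Decoding one token at a time can split multi-byte characters across chunks;
                # onnxruntime-genai also provides tokenizer.create_stream() for incremental decoding.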
token_count += 1
# Secondary check: Stop if the model explicitly generates the <|end|> string literal.
if decoded_chunk == "<|end|>":
logging.info("Assistant explicitly generated <|end|> token string.")
break
yield decoded_chunk # Yield just the text chunk
except Exception as loop_error:
logging.error(f"Error inside generation loop: {loop_error}", exc_info=True)
yield f"\n\nError during token generation: {loop_error}"
break # Exit loop on error
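        # Simple throughput metrics: time-to-first-token (TTFT, in ms) and tokens per second (TPS).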
end_time = time.time()
        ttft = (first_token_time - start_time) * 1000 if first_token_time is not None else -1
total_time = end_time - start_time
tps = (token_count / total_time) if total_time > 0 else 0
logging.info(f"Generation complete. Tokens: {token_count}, Total Time: {total_time:.2f}s, TTFT: {ttft:.2f}ms, TPS: {tps:.2f}")
model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})" # Reset status
except TypeError as te:
# Catch type errors specifically during setup if the input format is still wrong
logging.error(f"TypeError during generation setup: {te}", exc_info=True)
logging.error("Check if the input format {'input_ids': token_array} is correct.")
model_status = f"Generation Setup TypeError: {te}"
yield f"\n\nSorry, a TypeError occurred setting up generation: {te}"
except AttributeError as ae:
# Catch potential future API changes or issues during generation setup
logging.error(f"AttributeError during generation setup: {ae}", exc_info=True)
model_status = f"Generation Setup Error: {ae}"
yield f"\n\nSorry, an error occurred setting up generation: {ae}"
except Exception as e:
logging.error(f"Error during generation: {e}", exc_info=True)
model_status = f"Error during generation: {e}"
yield f"\n\nSorry, an error occurred during generation: {e}" # Yield error message
# --- Gradio Interface Functions ---
# 1. Function to add user message to chat history
def add_user_message(user_message, history):
"""Adds the user's message to the chat history for display."""
if not user_message:
return "", history # Clear input, return unchanged history
history = history + [[user_message, None]] # Append user message, leave bot response None
return "", history # Clear input textbox, return updated history
# 2. Function to handle bot response generation and streaming
def generate_bot_response(history, max_length, temperature, top_p, top_k):
"""Generates the bot's response based on the history and streams it."""
if not history or history[-1][1] is not None:
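        # No pending user turn (the last assistant slot is already filled), so there is nothing to generate.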
return history
user_prompt = history[-1][0] # Get the latest user prompt
model_history = history[:-1] # Prepare history for the model
response_stream = generate_response_stream(
user_prompt, model_history, max_length, temperature, top_p, top_k
)
history[-1][1] = "" # Initialize the bot response string in the history
for chunk in response_stream:
history[-1][1] += chunk # Append the chunk to the bot's message in history
yield history # Yield the *entire updated history* back to Chatbot
# 3. Function to clear chat
def clear_chat():
"""Clears the chat history and input."""
global model_status
if model and tokenizer and not model_status.startswith("Error") and not model_status.startswith("FATAL"):
model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"
return None, [], model_status # Clear Textbox, Chatbot history, and update status display
# --- Initialize Model on App Start ---
try:
initialize_model()
except Exception as e:
    model_status = f"FATAL: Model initialization failed: {e}"
    logging.error(model_status, exc_info=True)
# --- Gradio Interface ---
logging.info("Creating Gradio Interface...")
theme = gr.themes.Soft(
primary_hue="blue",
secondary_hue="sky",
neutral_hue="slate",
font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
)
with gr.Blocks(theme=theme, title="Phi-4 Mini ONNX Chat") as demo:
# Header Section
with gr.Row(equal_height=False):
with gr.Column(scale=3):
gr.Markdown(f"""
# Phi-4 Mini Instruct ONNX Chat 🤖
Interact with the quantized `{model_variant_name}` version of [`{MODEL_REPO}`]({HF_MODEL_URL})
running efficiently via [`onnxruntime-genai`]({ORT_GENAI_URL}) ({EXECUTION_PROVIDER.upper()}).
""")
with gr.Column(scale=1, min_width=150):
gr.Image(HF_LOGO_URL, elem_id="hf-logo", show_label=False, show_download_button=False, container=False, height=50)
model_status_text = gr.Textbox(value=model_status, label="Model Status", interactive=False, max_lines=2)
# Main Layout
with gr.Row():
# Chat Column
with gr.Column(scale=3):
chatbot = gr.Chatbot(
label="Conversation",
height=600,
layout="bubble",
bubble_full_width=False,
avatar_images=(None, PHI_LOGO_URL)
)
with gr.Row():
prompt_input = gr.Textbox(
label="Your Message",
                    placeholder="Type your message here...",
lines=4,
scale=9
)
with gr.Column(scale=1, min_width=120):
submit_button = gr.Button("Send", variant="primary", size="lg")
clear_button = gr.Button("🗑️ Clear Chat", variant="secondary")
# Settings Column
with gr.Column(scale=1, min_width=250):
gr.Markdown("### ⚙️ Generation Settings")
with gr.Group():
                max_length = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max Length", info="Max total tokens (prompt + response).")
temperature = gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature", info="0.0 = deterministic\n>1.0 = more random")
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-P", info="Nucleus sampling probability.")
top_k = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K", info="Limit to K most likely tokens (0=disable).")
gr.Markdown("---")
gr.Markdown("ℹ️ **Note:** Uses Phi-4 instruction format: \n`<|user|>\nPROMPT<|end|>\n<|assistant|>`")
gr.Markdown(f"Running on **{EXECUTION_PROVIDER.upper()}**.")
# Event Listeners
bot_response_inputs = [chatbot, max_length, temperature, top_p, top_k]
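    # Both submit paths follow the same two-step pattern: add_user_message appends the user turn and
    # clears the textbox immediately (unqueued), then the chained .then() streams the bot reply into
    # the last history entry via generate_bot_response.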
submit_event = prompt_input.submit(
fn=add_user_message,
inputs=[prompt_input, chatbot],
outputs=[prompt_input, chatbot],
queue=False,
).then(
fn=generate_bot_response,
inputs=bot_response_inputs,
outputs=[chatbot],
api_name="chat"
)
submit_button.click(
fn=add_user_message,
inputs=[prompt_input, chatbot],
outputs=[prompt_input, chatbot],
queue=False,
).then(
fn=generate_bot_response,
inputs=bot_response_inputs,
outputs=[chatbot],
api_name=False
)
clear_button.click(
fn=clear_chat,
inputs=None,
outputs=[prompt_input, chatbot, model_status_text],
queue=False
)
# Launch the Gradio app
logging.info("Launching Gradio App...")
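# Queueing is configured so concurrent requests are handled and the streaming (generator) chat callback
# can deliver partial updates; max_size caps the number of waiting requests.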
demo.queue(max_size=20)
demo.launch(show_error=True, max_threads=40)