Update app.py
app.py
CHANGED
@@ -1,20 +1,205 @@
import gradio as gr
import onnxruntime_genai as og
import time
import os
from huggingface_hub import snapshot_download
import argparse
import logging

# --- Logging Setup ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Configuration ---
MODEL_REPO = "microsoft/Phi-4-mini-instruct-onnx"

# --- Defaulting to CPU INT4 for Hugging Face Spaces ---
# Free Spaces generally provide CPU resources.
# If deploying on a paid GPU Space, you would change these
# and the requirements.txt accordingly.
EXECUTION_PROVIDER = "cpu"
MODEL_VARIANT_GLOB = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/*"
# Ensure requirements.txt lists: onnxruntime-genai
# --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---

# --- Alternative GPU Configuration (Requires GPU Space & requirements change) ---
# EXECUTION_PROVIDER = "cuda"
# MODEL_VARIANT_GLOB = "gpu/gpu-int4-rtn-block-32/*"
# Ensure requirements.txt lists: onnxruntime-genai-cuda
# --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---
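# For reference, a minimal sketch of requirements.txt for this configuration
# (package names assumed from the comments above; pin versions as needed):
#
#   gradio
#   huggingface_hub
#   onnxruntime-genai        # CPU default; swap for onnxruntime-genai-cuda on a GPU Space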

LOCAL_MODEL_DIR = "./phi4-mini-onnx-model"  # Directory within the Space
HF_LOGO_URL = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"  # Official HF Logo

# Global variables for model and tokenizer
model = None
tokenizer = None
model_variant_name = os.path.basename(os.path.dirname(MODEL_VARIANT_GLOB))  # For display

# --- Model Download and Load ---
def initialize_model():
    """Downloads and loads the ONNX model and tokenizer."""
    global model, tokenizer
    logging.info("--- Initializing ONNX Runtime GenAI ---")

    # --- Download ---
    model_variant_dir = os.path.join(LOCAL_MODEL_DIR, os.path.dirname(MODEL_VARIANT_GLOB))
    if os.path.exists(model_variant_dir) and os.listdir(model_variant_dir):
        logging.info(f"Model variant found in {model_variant_dir}. Skipping download.")
        model_path = model_variant_dir
    else:
        logging.info(f"Downloading model variant '{MODEL_VARIANT_GLOB}' from {MODEL_REPO}...")
        try:
            # Use cache_dir for potentially faster re-runs if Space storage allows caching
            # cache_dir = os.path.join(LOCAL_MODEL_DIR, ".cache")  # Optional: define cache dir
            snapshot_download(
                MODEL_REPO,
                allow_patterns=[MODEL_VARIANT_GLOB],
                local_dir=LOCAL_MODEL_DIR,
                local_dir_use_symlinks=False  # Safest for cross-platform/Space compatibility
                # cache_dir=cache_dir  # Optional
            )
            model_path = model_variant_dir
            logging.info(f"Model downloaded to: {model_path}")
        except Exception as e:
            logging.error(f"Error downloading model: {e}", exc_info=True)
            logging.error("Please ensure the Space has internet access and necessary permissions.")
            logging.error("Check Hugging Face Hub status if issues persist.")
            # Optionally raise to stop the app, or try to proceed if partial files might exist
            raise RuntimeError(f"Failed to download model: {e}")

    # --- Load ---
    logging.info(f"Loading model from: {model_path}")
    logging.info(f"Using Execution Provider: {EXECUTION_PROVIDER.upper()}")
    try:
        # The execution provider is determined by which onnxruntime-genai package
        # is installed (onnxruntime-genai for CPU, onnxruntime-genai-cuda for CUDA);
        # og.Model() only needs the path to the model folder.
        model = og.Model(model_path)
        tokenizer = og.Tokenizer(model)
        logging.info("Model and Tokenizer loaded successfully.")
    except Exception as e:
        logging.error(f"Error loading model or tokenizer: {e}", exc_info=True)
        logging.error(f"Ensure the correct onnxruntime-genai package is installed (check requirements.txt) for {EXECUTION_PROVIDER}.")
        logging.error(f"Verify model files integrity in '{model_path}'.")
        raise RuntimeError(f"Failed to load model: {e}")
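
# For reference (assumed layout): after a successful download, model_path should
# contain the ONNX variant files that og.Model() / og.Tokenizer() read directly,
# e.g. genai_config.json, model.onnx, model.onnx.data and the tokenizer files.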

# --- Generation Function ---
def generate_response(prompt, history, max_length=1024, temperature=0.7, top_p=0.9, top_k=50):
    """Generates a streaming response using the Phi-4 ONNX model."""
    if not model or not tokenizer:
        # This function is a generator, so yield (rather than return) the message
        # to ensure Gradio displays it in the chat window.
        yield history + [[prompt, "Error: Model not initialized. Please check logs."]]
        return
    if not prompt:
        yield history + [[prompt, "Please enter a prompt."]]
        return

    # --- Prepare the prompt using the Phi-4 instruct format ---
    # "<|user|>\n{user_message}<|end|>\n<|assistant|>\n{assistant_message}<|end|>"
    full_prompt = ""
    for user_msg, assistant_msg in history:
        full_prompt += f"<|user|>\n{user_msg}<|end|>\n"
        if assistant_msg:  # Add assistant message only if it exists
            full_prompt += f"<|assistant|>\n{assistant_msg}<|end|>\n"

    # Add the current user prompt and the trigger for the assistant's response
    full_prompt += f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
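    # For illustration (assumed single-turn example): with history=[] and
    # prompt="Hello", full_prompt at this point is:
    #   "<|user|>\nHello<|end|>\n<|assistant|>\n"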

    logging.info(f"Generating response for prompt (last part): ...{prompt[-50:]}")
    # logging.debug(f"Full Formatted Prompt:\n{full_prompt}")  # Use debug level

    try:
        input_tokens = tokenizer.encode(full_prompt)

        search_options = {
            "max_length": max_length,
            "temperature": temperature,
            "top_p": top_p,
            "top_k": top_k,
            "do_sample": True,  # Sampling is generally preferred for chat
            "eos_token_id": tokenizer.eos_token_id,  # Important for stopping generation
            "pad_token_id": tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id,  # Use EOS if PAD not explicit
        }

        params = og.GeneratorParams(model)
        params.set_search_options(**search_options)
        params.input_ids = input_tokens

        start_time = time.time()
        generator = og.Generator(model, params)
        response_text = ""
        logging.info("Streaming response...")

        # Simple token streaming - yield the growing chat history so Gradio updates the Chatbot gradually
        while not generator.is_done():
            generator.compute_logits()
            generator.generate_next_token()
            next_token = generator.get_next_tokens()[0]
            # Important: check for the EOS token ID to stop manually if needed
            if next_token == search_options["eos_token_id"]:
                break
            decoded_chunk = tokenizer.decode([next_token])
            response_text += decoded_chunk
            yield history + [[prompt, response_text]]  # Yield intermediate results for streaming effect

        end_time = time.time()
        logging.info(f"Generation complete. Time taken: {end_time - start_time:.2f} seconds")
        logging.info(f"Full Response (last 100 chars): ...{response_text[-100:]}")

        # Final yield with the complete text appended to the chat history
        yield history + [[prompt, response_text.strip()]]

    except Exception as e:
        logging.error(f"Error during generation: {e}", exc_info=True)
        yield history + [[prompt, f"Sorry, an error occurred during generation: {e}"]]  # Surface the error in the chat
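
# Illustrative usage (assumption, not part of the app's normal flow): once
# initialize_model() has run, the generator can be smoke-tested from a Python
# shell like this:
#
#   for partial_history in generate_response("Say hello in one sentence.", []):
#       pass
#   print(partial_history[-1][1])  # final assistant reply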

# --- Initialize Model on App Start ---
try:
    initialize_model()
except Exception as e:
    print(f"FATAL: Model initialization failed: {e}")
    # Optionally create a dummy Gradio interface showing the error
    # Or just let the script exit/fail in the Space environment

# --- Gradio Interface ---
logging.info("Creating Gradio Interface...")

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(f"""
    # Phi-4 Mini Instruct ONNX Demo
    Powered by [`onnxruntime-genai`](https://github.com/microsoft/onnxruntime-genai) running the **{EXECUTION_PROVIDER.upper()}** INT4 ONNX Runtime version of [`{MODEL_REPO}`](https://huggingface.co/{MODEL_REPO}).
    Model Variant: `{model_variant_name}`

    <img src="{HF_LOGO_URL}" alt="Hugging Face Logo" style="display: inline-block; height: 1.5em; vertical-align: middle;"> This Space demonstrates running Phi-4 Mini efficiently with ONNX Runtime.
    """)

    chatbot = gr.Chatbot(label="Chat History", height=500, layout="bubble", bubble_full_width=False)
    msg = gr.Textbox(
        label="Your Prompt",
        placeholder="Type your message here...",
        lines=3,
        info="The Phi-4 instruct format (<|user|>...<|end|><|assistant|>) is applied automatically."
    )
    clear = gr.Button("Clear Chat")

    with gr.Accordion("Generation Parameters", open=False):
        max_length = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max Length", info="Maximum number of tokens to generate.")
        temperature = gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature", info="Higher values (e.g., 0.8) make output more random; lower values (e.g., 0.2) make it more deterministic.")
        top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-P (Nucleus Sampling)", info="Filters vocabulary to the smallest set whose cumulative probability exceeds P. Set to 1.0 to disable.")
        top_k = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K", info="Filters vocabulary to the K most likely tokens. Set to 0 to disable.")

    # Use the generator function for streaming:
    # msg.submit returns a generator, which updates the chatbot gradually
    msg.submit(
        generate_response,
        inputs=[msg, chatbot, max_length, temperature, top_p, top_k],
        outputs=[chatbot]
    )

    # Connect the clear button
    clear.click(lambda: (None, None), None, [msg, chatbot], queue=False)  # Clear input and chatbot

    gr.Markdown("Enter your prompt and press Enter. Adjust generation parameters in the accordion above.")

logging.info("Launching Gradio App...")
# Setting share=False is the default and recommended for Spaces
# queue() is important for handling multiple users
# debug=True can be useful for local testing but should generally be False in production/Spaces
demo.queue()
demo.launch(show_error=True)  # Show errors in the UI for easier debugging in Spaces
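
# To run this app locally (assumed workflow): install the dependencies, e.g.
#   pip install gradio huggingface_hub onnxruntime-genai
# then start the server with:
#   python app.py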