Update app.py
app.py (changed)
Removed from the previous version:

@@ -13,33 +13,35 @@
- # Free Spaces generally provide CPU resources.
- # If deploying on a paid GPU Space, you would change these
- # and the requirements.txt accordingly.
- # --- Alternative GPU Configuration
- HF_LOGO_URL = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"
- global model, tokenizer

@@ -49,59 +51,62 @@ def initialize_model():
- # Use cache_dir for potentially faster re-runs if space storage allows caching
- # cache_dir = os.path.join(LOCAL_MODEL_DIR, ".cache") # Optional: Define cache dir
- local_dir_use_symlinks=False
- # cache_dir=cache_dir # Optional
- logging.error("Check Hugging Face Hub status if issues persist.")
- # Optionally raise to stop the app, or try to proceed if partial files might exist
- logging.info(
- logging.error("Verify model files integrity in '{model_path}'.")
- def generate_response(prompt, history, max_length
- """Generates a response using the Phi-4 ONNX model."""
- # "<|user|>\n{user_message}<|end|>\n<|assistant|>\n{assistant_message}<|end|>"
- if assistant_msg:
- # Add the current user prompt and the trigger for the assistant's response
- logging.info(f"Generating response
- # logging.debug(f"Full Formatted Prompt:\n{full_prompt}")

@@ -111,9 +116,9 @@ def generate_response(prompt, history, max_length=1024, temperature=0.7, top_p=0
- "do_sample": True,
- "eos_token_id": tokenizer.eos_token_id,
- "pad_token_id": tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id,

@@ -123,83 +128,155 @@
(old lines 163-192, the previous Gradio interface block, were also removed)
- # debug=True can be useful for local testing but should generally be False in production/Spaces
- demo.queue()
- demo.launch(show_error=True) # Show errors in the UI for easier debugging in Spaces
The updated file (app.py, lines 13-282 of the new version; added lines are marked with "+"):

 MODEL_REPO = "microsoft/Phi-4-mini-instruct-onnx"

 # --- Defaulting to CPU INT4 for Hugging Face Spaces ---
 EXECUTION_PROVIDER = "cpu"
 MODEL_VARIANT_GLOB = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/*"
 # Ensure requirements.txt lists: onnxruntime-genai
 # --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---

+# --- (Optional) Alternative GPU Configuration ---
 # EXECUTION_PROVIDER = "cuda"
 # MODEL_VARIANT_GLOB = "gpu/gpu-int4-rtn-block-32/*"
 # Ensure requirements.txt lists: onnxruntime-genai-cuda
 # --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---

 LOCAL_MODEL_DIR = "./phi4-mini-onnx-model" # Directory within the Space
+HF_LOGO_URL = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"
+HF_MODEL_URL = f"https://huggingface.co/{MODEL_REPO}"
+ORT_GENAI_URL = "https://github.com/microsoft/onnxruntime-genai"
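For reference (not part of app.py): a minimal requirements.txt consistent with the CPU configuration above. Only onnxruntime-genai (or onnxruntime-genai-cuda for the GPU variant) is stated in the comments; adding gradio and huggingface_hub is an assumption based on the gr.* and snapshot_download calls the file uses, and no version pins are implied by this change.

# requirements.txt (sketch, CPU variant)
onnxruntime-genai
gradio
huggingface_hub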
 # Global variables for model and tokenizer
 model = None
 tokenizer = None
 model_variant_name = os.path.basename(os.path.dirname(MODEL_VARIANT_GLOB)) # For display
+model_status = "Initializing..."

 # --- Model Download and Load ---
 def initialize_model():
     """Downloads and loads the ONNX model and tokenizer."""
+    global model, tokenizer, model_status
     logging.info("--- Initializing ONNX Runtime GenAI ---")
+    model_status = "Downloading model..."
+    logging.info(model_status)

     # --- Download ---
     model_variant_dir = os.path.join(LOCAL_MODEL_DIR, os.path.dirname(MODEL_VARIANT_GLOB))
 ... (lines 48-50 unchanged, not shown)
     else:
         logging.info(f"Downloading model variant '{MODEL_VARIANT_GLOB}' from {MODEL_REPO}...")
         try:
             snapshot_download(
                 MODEL_REPO,
                 allow_patterns=[MODEL_VARIANT_GLOB],
                 local_dir=LOCAL_MODEL_DIR,
+                local_dir_use_symlinks=False
             )
             model_path = model_variant_dir
             logging.info(f"Model downloaded to: {model_path}")
         except Exception as e:
             logging.error(f"Error downloading model: {e}", exc_info=True)
+            model_status = f"Error downloading model: {e}"
             raise RuntimeError(f"Failed to download model: {e}")

     # --- Load ---
+    model_status = f"Loading model ({EXECUTION_PROVIDER.upper()})..."
+    logging.info(model_status)
     try:
+        # Determine device type based on execution provider string
+        if EXECUTION_PROVIDER.lower() == "cuda":
+            og_device_type = og.DeviceType.CUDA
+        elif EXECUTION_PROVIDER.lower() == "dml":
+            og_device_type = og.DeviceType.DML # Requires onnxruntime-genai-directml
+        else: # Default to CPU
+            og_device_type = og.DeviceType.CPU
+
         model = og.Model(model_path, og_device_type)
         tokenizer = og.Tokenizer(model)
+        model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"
         logging.info("Model and Tokenizer loaded successfully.")
     except Exception as e:
         logging.error(f"Error loading model or tokenizer: {e}", exc_info=True)
+        model_status = f"Error loading model: {e}"
         raise RuntimeError(f"Failed to load model: {e}")
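For reference (not part of app.py): a standalone smoke-test sketch that mirrors the onnxruntime-genai calls used above (og.Model, og.Tokenizer, og.GeneratorParams, og.Generator) against the downloaded CPU variant directory. The params.set_search_options(...) and params.input_ids wiring follows common onnxruntime-genai examples and is an assumption about what the unchanged lines 113-115 and 125-127 of app.py do; treat it as a sketch, not the file's exact code.

# smoke_test.py (sketch; run after the model variant has been downloaded)
import onnxruntime_genai as og

model_dir = "./phi4-mini-onnx-model/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4"
model = og.Model(model_dir, og.DeviceType.CPU)   # same call shape as app.py
tokenizer = og.Tokenizer(model)

prompt = "<|user|>\nSay hello in one sentence.<|end|>\n<|assistant|>\n"
input_tokens = tokenizer.encode(prompt)

params = og.GeneratorParams(model)
params.set_search_options(max_length=128, temperature=0.7, top_p=0.9, top_k=50, do_sample=True)  # assumed API
params.input_ids = input_tokens  # assumed API

# Token-by-token loop, as in generate_response()
generator = og.Generator(model, params)
text = ""
while not generator.is_done():
    generator.compute_logits()
    generator.generate_next_token()
    text += tokenizer.decode([generator.get_next_tokens()[0]])
print(text)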
 # --- Generation Function ---
+def generate_response(prompt, history, max_length, temperature, top_p, top_k):
+    """Generates a response using the Phi-4 ONNX model, yielding partial results."""
+    global model_status
     if not model or not tokenizer:
+        model_status = "Error: Model not initialized!"
+        yield "Error: Model not initialized. Please check logs."
+        return
     if not prompt:
+        yield "Please enter a prompt."
+        return

     # --- Prepare the prompt using the Phi-4 instruct format ---
     full_prompt = ""
     for user_msg, assistant_msg in history:
         full_prompt += f"<|user|>\n{user_msg}<|end|>\n"
+        if assistant_msg:
             full_prompt += f"<|assistant|>\n{assistant_msg}<|end|>\n"
     full_prompt += f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"

+    logging.info(f"Generating response (MaxL: {max_length}, Temp: {temperature}, TopP: {top_p}, TopK: {top_k})")
+    # logging.debug(f"Full Formatted Prompt:\n{full_prompt}")

     try:
         input_tokens = tokenizer.encode(full_prompt)
 ... (lines 113-115 unchanged, not shown)
             "temperature": temperature,
             "top_p": top_p,
             "top_k": top_k,
+            "do_sample": True,
+            "eos_token_id": tokenizer.eos_token_id,
+            "pad_token_id": tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id,
         }

         params = og.GeneratorParams(model)
 ... (lines 125-127 unchanged, not shown)
         start_time = time.time()
         generator = og.Generator(model, params)
         response_text = ""
+        model_status = "Generating..." # Update status indicator
         logging.info("Streaming response...")

+        first_token_time = None
         while not generator.is_done():
             generator.compute_logits()
             generator.generate_next_token()
+            if first_token_time is None:
+                first_token_time = time.time() # Record time to first token
             next_token = generator.get_next_tokens()[0]
+
             if next_token == search_options["eos_token_id"]:
+                logging.info("EOS token encountered.")
                 break
+
             decoded_chunk = tokenizer.decode([next_token])
+
+            # Handle potential decoding issues or special tokens if necessary
+            # (e.g., some models might output "<|end|>" which you might want to strip)
+            if decoded_chunk == "<|end|>": # Example: Stop if assistant outputs end token explicitly
+                logging.info("Assistant explicitly generated <|end|> token.")
+                break
+
             response_text += decoded_chunk
             yield response_text # Yield intermediate results for streaming effect

         end_time = time.time()
+        ttft = (first_token_time - start_time) * 1000 if first_token_time else -1
+        total_time = end_time - start_time
+        token_count = len(tokenizer.decode(generator.get_output_sequences()[0])) # Approx token count
+        tps = (token_count / total_time) if total_time > 0 else 0

+        logging.info(f"Generation complete. Tokens: ~{token_count}, Total Time: {total_time:.2f}s, TTFT: {ttft:.2f}ms, TPS: {tps:.2f}")
+        model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})" # Reset status
+
+        # Final yield with the complete text
         yield response_text.strip()

     except Exception as e:
         logging.error(f"Error during generation: {e}", exc_info=True)
+        model_status = f"Error during generation: {e}"
+        yield f"Sorry, an error occurred during generation: {e}"

+# --- Clear Chat Function ---
+def clear_chat():
+    return None, None # Clears Textbox and Chatbot
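For reference (not part of app.py): with the formatting loop above, a one-turn history plus a new prompt is flattened into a single Phi-4 instruct string, and because generate_response is a generator it can also be consumed outside Gradio. The history and prompt strings below are made up for illustration.

# What full_prompt looks like for history=[("Hi", "Hello!")] and prompt="Tell me a joke":
# "<|user|>\nHi<|end|>\n<|assistant|>\nHello!<|end|>\n<|user|>\nTell me a joke<|end|>\n<|assistant|>\n"

# Consuming the streaming generator directly (prints progressively longer partial responses):
for partial in generate_response("Tell me a joke", [("Hi", "Hello!")], 256, 0.7, 0.9, 50):
    print(partial)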
 # --- Initialize Model on App Start ---
+# Wrap in try-except to allow Gradio UI to potentially load even if model fails
 try:
     initialize_model()
 except Exception as e:
     print(f"FATAL: Model initialization failed: {e}")
+    model_status = f"FATAL ERROR during init: {e}"
+    # The UI will still load, but generation will fail. The status will show the error.

 # --- Gradio Interface ---
 logging.info("Creating Gradio Interface...")

+# Select a theme
+theme = gr.themes.Soft(
+    primary_hue="blue",
+    secondary_hue="sky",
+    neutral_hue="slate",
+    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
+).set(
+    # Customize specific component styles if needed
+    # button_primary_background_fill="*primary_500",
+    # button_primary_background_fill_hover="*primary_400",
+)

+with gr.Blocks(theme=theme, title="Phi-4 Mini ONNX Chat") as demo:
+    # Header Section
+    with gr.Row(equal_height=False):
+        with gr.Column(scale=3):
+            gr.Markdown(f"""
+            # Phi-4 Mini Instruct ONNX Chat 🤖
+            Interact with the quantized `{model_variant_name}` version of [`{MODEL_REPO}`]({HF_MODEL_URL})
+            running efficiently via [`onnxruntime-genai`]({ORT_GENAI_URL}).
+            """)
+        with gr.Column(scale=1, min_width=150):
+            gr.Image(HF_LOGO_URL, elem_id="hf-logo", show_label=False, show_download_button=False, container=False, height=50)
+            model_status_text = gr.Textbox(value=model_status, label="Model Status", interactive=False, max_lines=2)

+    # Main Layout (Chat on Left, Settings on Right)
+    with gr.Row():
+        # Chat Column
+        with gr.Column(scale=3):
+            chatbot = gr.Chatbot(
+                label="Conversation",
+                height=600,
+                layout="bubble",
+                bubble_full_width=False,
+                avatar_images=(None, "https://microsoft.github.io/phi/assets/img/logo-final.png") # (user, bot) - Optional: Add user avatar path/URL if desired
+            )
+            with gr.Row():
+                prompt_input = gr.Textbox(
+                    label="Your Message",
+                    placeholder="<|user|>\nType your message here...\n<|end|>",
+                    lines=4,
+                    scale=9 # Make textbox wider
+                )
+                submit_button = gr.Button("Send", variant="primary", scale=1, min_width=120) # Primary send button
+                clear_button = gr.Button("🗑️ Clear", variant="secondary", scale=1, min_width=120) # Secondary clear button

+        # Settings Column
+        with gr.Column(scale=1, min_width=250):
+            gr.Markdown("### ⚙️ Generation Settings")
+            with gr.Group(): # Group settings visually
+                max_length = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max Length", info="Max tokens in response.")
+                temperature = gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature", info="0.0 = deterministic\n>1.0 = more random")
+                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-P", info="Nucleus sampling probability.")
+                top_k = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K", info="Limit to K most likely tokens (0=disable).")

+            gr.Markdown("---") # Separator
+            gr.Markdown("ℹ️ **Note:** Uses Phi-4 instruction format: \n`<|user|>\nPROMPT<|end|>\n<|assistant|>`")

+    # Event Listeners (Connecting UI components to functions)

+    # Define reusable inputs list for generation
+    gen_inputs = [prompt_input, chatbot, max_length, temperature, top_p, top_k]

+    # Submit action (using streaming yields from generate_response)
+    submit_button.click(
+        fn=generate_response,
+        inputs=gen_inputs,
+        outputs=[chatbot], # Output directly streams to chatbot
+        queue=True # Enable queuing
+    )
+    # Allow submitting via Enter key in the textbox as well
+    prompt_input.submit(
+        fn=generate_response,
+        inputs=gen_inputs,
+        outputs=[chatbot],
+        queue=True
     )

+    # Clear button action
+    clear_button.click(
+        fn=clear_chat,
+        inputs=None,
+        outputs=[prompt_input, chatbot], # Clear both input and chat history
+        queue=False # No need to queue clearing
+    )

+# Launch the Gradio app
 logging.info("Launching Gradio App...")
+demo.queue() # Enable queuing for handling concurrent users/requests
+demo.launch(show_error=True, max_threads=40) # show_error=True helps debug in Spaces
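For reference (not part of app.py): a local-testing variant of the launch call. server_name, server_port, share, and show_error are standard gradio demo.launch parameters; the specific values are assumptions for running outside a Space, not something specified by this change.

# Local run sketch (instead of the Spaces launch above)
demo.queue()
demo.launch(server_name="0.0.0.0", server_port=7860, share=False, show_error=True)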
|