Athspi committed on
Commit 513f7a6 · verified · 1 Parent(s): 7fd9f7d

Update app.py

Files changed (1)
  1. app.py +155 -78
app.py CHANGED
@@ -13,33 +13,35 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
 MODEL_REPO = "microsoft/Phi-4-mini-instruct-onnx"
 
 # --- Defaulting to CPU INT4 for Hugging Face Spaces ---
-# Free Spaces generally provide CPU resources.
-# If deploying on a paid GPU Space, you would change these
-# and the requirements.txt accordingly.
 EXECUTION_PROVIDER = "cpu"
 MODEL_VARIANT_GLOB = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/*"
 # Ensure requirements.txt lists: onnxruntime-genai
 # --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---
 
-# --- Alternative GPU Configuration (Requires GPU Space & requirements change) ---
 # EXECUTION_PROVIDER = "cuda"
 # MODEL_VARIANT_GLOB = "gpu/gpu-int4-rtn-block-32/*"
 # Ensure requirements.txt lists: onnxruntime-genai-cuda
 # --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---
 
 LOCAL_MODEL_DIR = "./phi4-mini-onnx-model" # Directory within the Space
-HF_LOGO_URL = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg" # Official HF Logo
 
 # Global variables for model and tokenizer
 model = None
 tokenizer = None
 model_variant_name = os.path.basename(os.path.dirname(MODEL_VARIANT_GLOB)) # For display
 
 # --- Model Download and Load ---
 def initialize_model():
     """Downloads and loads the ONNX model and tokenizer."""
-    global model, tokenizer
     logging.info("--- Initializing ONNX Runtime GenAI ---")
 
     # --- Download ---
     model_variant_dir = os.path.join(LOCAL_MODEL_DIR, os.path.dirname(MODEL_VARIANT_GLOB))
@@ -49,59 +51,62 @@ def initialize_model():
     else:
         logging.info(f"Downloading model variant '{MODEL_VARIANT_GLOB}' from {MODEL_REPO}...")
         try:
-            # Use cache_dir for potentially faster re-runs if space storage allows caching
-            # cache_dir = os.path.join(LOCAL_MODEL_DIR, ".cache") # Optional: Define cache dir
             snapshot_download(
                 MODEL_REPO,
                 allow_patterns=[MODEL_VARIANT_GLOB],
                 local_dir=LOCAL_MODEL_DIR,
-                local_dir_use_symlinks=False # Safest for cross-platform/Space compatibility
-                # cache_dir=cache_dir # Optional
             )
             model_path = model_variant_dir
             logging.info(f"Model downloaded to: {model_path}")
         except Exception as e:
             logging.error(f"Error downloading model: {e}", exc_info=True)
-            logging.error("Please ensure the Space has internet access and necessary permissions.")
-            logging.error("Check Hugging Face Hub status if issues persist.")
-            # Optionally raise to stop the app, or try to proceed if partial files might exist
             raise RuntimeError(f"Failed to download model: {e}")
 
     # --- Load ---
-    logging.info(f"Loading model from: {model_path}")
-    logging.info(f"Using Execution Provider: {EXECUTION_PROVIDER.upper()}")
     try:
-        og_device_type = getattr(og.DeviceType, EXECUTION_PROVIDER.upper(), og.DeviceType.CPU)
         model = og.Model(model_path, og_device_type)
         tokenizer = og.Tokenizer(model)
         logging.info("Model and Tokenizer loaded successfully.")
     except Exception as e:
         logging.error(f"Error loading model or tokenizer: {e}", exc_info=True)
-        logging.error(f"Ensure the correct onnxruntime-genai package is installed (check requirements.txt) for {EXECUTION_PROVIDER}.")
-        logging.error("Verify model files integrity in '{model_path}'.")
         raise RuntimeError(f"Failed to load model: {e}")
 
 # --- Generation Function ---
-def generate_response(prompt, history, max_length=1024, temperature=0.7, top_p=0.9, top_k=50):
-    """Generates a response using the Phi-4 ONNX model."""
     if not model or not tokenizer:
-        return "Error: Model not initialized. Please check logs."
     if not prompt:
-        return "Please enter a prompt."
 
     # --- Prepare the prompt using the Phi-4 instruct format ---
-    # "<|user|>\n{user_message}<|end|>\n<|assistant|>\n{assistant_message}<|end|>"
     full_prompt = ""
     for user_msg, assistant_msg in history:
         full_prompt += f"<|user|>\n{user_msg}<|end|>\n"
-        if assistant_msg: # Add assistant message only if it exists
             full_prompt += f"<|assistant|>\n{assistant_msg}<|end|>\n"
-
-    # Add the current user prompt and the trigger for the assistant's response
     full_prompt += f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
 
-    logging.info(f"Generating response for prompt (last part): ...{prompt[-50:]}")
-    # logging.debug(f"Full Formatted Prompt:\n{full_prompt}") # Use debug level
 
     try:
         input_tokens = tokenizer.encode(full_prompt)
@@ -111,9 +116,9 @@ def generate_response(prompt, history, max_length=1024, temperature=0.7, top_p=0
             "temperature": temperature,
             "top_p": top_p,
             "top_k": top_k,
-            "do_sample": True, # Sampling is generally preferred for chat
-            "eos_token_id": tokenizer.eos_token_id, # Important for stopping generation
-            "pad_token_id": tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id, # Use EOS if PAD not explicit
         }
 
         params = og.GeneratorParams(model)
@@ -123,83 +128,155 @@ def generate_response(prompt, history, max_length=1024, temperature=0.7, top_p=0
         start_time = time.time()
         generator = og.Generator(model, params)
         response_text = ""
         logging.info("Streaming response...")
 
-        # Simple token streaming - yield partial results for Gradio
         while not generator.is_done():
             generator.compute_logits()
             generator.generate_next_token()
             next_token = generator.get_next_tokens()[0]
-            # Important: Check for EOS token ID to stop manually if needed
             if next_token == search_options["eos_token_id"]:
                 break
             decoded_chunk = tokenizer.decode([next_token])
             response_text += decoded_chunk
             yield response_text # Yield intermediate results for streaming effect
 
         end_time = time.time()
-        logging.info(f"Generation complete. Time taken: {end_time - start_time:.2f} seconds")
-        logging.info(f"Full Response (last 100 chars): ...{response_text[-100:]}")
 
-        # Final yield with the complete text (or return if not using yield in Gradio setup)
         yield response_text.strip()
 
     except Exception as e:
         logging.error(f"Error during generation: {e}", exc_info=True)
-        yield f"Sorry, an error occurred during generation: {e}" # Yield error message
 
 
 # --- Initialize Model on App Start ---
 try:
     initialize_model()
 except Exception as e:
     print(f"FATAL: Model initialization failed: {e}")
-    # Optionally create a dummy Gradio interface showing the error
-    # Or just let the script exit/fail in the Space environment
 
 # --- Gradio Interface ---
 logging.info("Creating Gradio Interface...")
 
-with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown(f"""
-    # Phi-4 Mini Instruct ONNX Demo
-    Powered by [`onnxruntime-genai`](https://github.com/microsoft/onnxruntime-genai) running the **{EXECUTION_PROVIDER.upper()}** INT4 ONNX Runtime version of [`{MODEL_REPO}`](https://huggingface.co/{MODEL_REPO}).
-    Model Variant: `{model_variant_name}`
-
-    <img src="{HF_LOGO_URL}" alt="Hugging Face Logo" style="display: inline-block; height: 1.5em; vertical-align: middle;"> This Space demonstrates running Phi-4 Mini efficiently with ONNX Runtime.
-    """)
-
-    chatbot = gr.Chatbot(label="Chat History", height=500, layout="bubble", bubble_full_width=False)
-    msg = gr.Textbox(
-        label="Your Prompt",
-        placeholder="<|user|>\nType your message here...<|end|>\n<|assistant|>",
-        lines=3,
-        info="Using the recommended Phi-4 instruct format." # Add info text
-    )
-    clear = gr.Button("Clear Chat")
-
-    with gr.Accordion("Generation Parameters", open=False):
-        max_length = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max Length", info="Maximum number of tokens to generate.")
-        temperature = gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature", info="Higher values (e.g., 0.8) make output more random, lower values (e.g., 0.2) make it more deterministic.")
-        top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-P (Nucleus Sampling)", info="Filters vocabulary to the smallest set whose cumulative probability exceeds P. Set to 1.0 to disable.")
-        top_k = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K", info="Filters vocabulary to the K most likely tokens. Set to 0 to disable.")
-
-    # Use the streaming capability of ChatInterface
-    # msg.submit returns a generator which updates chatbot gradually
-    msg.submit(
-        generate_response,
-        inputs=[msg, chatbot, max_length, temperature, top_p, top_k],
-        outputs=[chatbot]
     )
 
-    # Connect the clear button
-    clear.click(lambda: (None, None), None, [msg, chatbot], queue=False) # Clear input and chatbot
-
-    gr.Markdown("Enter your prompt using the suggested format and press Enter. Adjust generation parameters in the accordion above.")
 
 logging.info("Launching Gradio App...")
-# Setting share=False is default and recommended for Spaces
-# queue() is important for handling multiple users
-# debug=True can be useful for local testing but should generally be False in production/Spaces
-demo.queue()
-demo.launch(show_error=True) # Show errors in the UI for easier debugging in Spaces
 
@@ -13,33 +13,35 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
 MODEL_REPO = "microsoft/Phi-4-mini-instruct-onnx"
 
 # --- Defaulting to CPU INT4 for Hugging Face Spaces ---
 EXECUTION_PROVIDER = "cpu"
 MODEL_VARIANT_GLOB = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/*"
 # Ensure requirements.txt lists: onnxruntime-genai
 # --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---
 
+# --- (Optional) Alternative GPU Configuration ---
 # EXECUTION_PROVIDER = "cuda"
 # MODEL_VARIANT_GLOB = "gpu/gpu-int4-rtn-block-32/*"
 # Ensure requirements.txt lists: onnxruntime-genai-cuda
 # --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---
 
 LOCAL_MODEL_DIR = "./phi4-mini-onnx-model" # Directory within the Space
+HF_LOGO_URL = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"
+HF_MODEL_URL = f"https://huggingface.co/{MODEL_REPO}"
+ORT_GENAI_URL = "https://github.com/microsoft/onnxruntime-genai"
 
 # Global variables for model and tokenizer
 model = None
 tokenizer = None
 model_variant_name = os.path.basename(os.path.dirname(MODEL_VARIANT_GLOB)) # For display
+model_status = "Initializing..."
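
For reference, a quick worked example (not part of the commit) of what `model_variant_name` evaluates to for the CPU variant glob configured above, using only the standard library:

```python
import os

MODEL_VARIANT_GLOB = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/*"

# os.path.dirname() drops the trailing "/*"; os.path.basename() keeps the leaf directory.
variant_dir = os.path.dirname(MODEL_VARIANT_GLOB)   # "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4"
variant_name = os.path.basename(variant_dir)        # "cpu-int4-rtn-block-32-acc-level-4"
print(variant_name)
```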
 
 # --- Model Download and Load ---
 def initialize_model():
     """Downloads and loads the ONNX model and tokenizer."""
+    global model, tokenizer, model_status
     logging.info("--- Initializing ONNX Runtime GenAI ---")
+    model_status = "Downloading model..."
+    logging.info(model_status)
 
     # --- Download ---
     model_variant_dir = os.path.join(LOCAL_MODEL_DIR, os.path.dirname(MODEL_VARIANT_GLOB))
@@ -49,59 +51,62 @@ def initialize_model():
     else:
         logging.info(f"Downloading model variant '{MODEL_VARIANT_GLOB}' from {MODEL_REPO}...")
         try:
             snapshot_download(
                 MODEL_REPO,
                 allow_patterns=[MODEL_VARIANT_GLOB],
                 local_dir=LOCAL_MODEL_DIR,
+                local_dir_use_symlinks=False
             )
             model_path = model_variant_dir
             logging.info(f"Model downloaded to: {model_path}")
         except Exception as e:
             logging.error(f"Error downloading model: {e}", exc_info=True)
+            model_status = f"Error downloading model: {e}"
             raise RuntimeError(f"Failed to download model: {e}")
 
     # --- Load ---
+    model_status = f"Loading model ({EXECUTION_PROVIDER.upper()})..."
+    logging.info(model_status)
     try:
+        # Determine device type based on execution provider string
+        if EXECUTION_PROVIDER.lower() == "cuda":
+            og_device_type = og.DeviceType.CUDA
+        elif EXECUTION_PROVIDER.lower() == "dml":
+            og_device_type = og.DeviceType.DML # Requires onnxruntime-genai-directml
+        else: # Default to CPU
+            og_device_type = og.DeviceType.CPU
+
         model = og.Model(model_path, og_device_type)
         tokenizer = og.Tokenizer(model)
+        model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})"
         logging.info("Model and Tokenizer loaded successfully.")
     except Exception as e:
         logging.error(f"Error loading model or tokenizer: {e}", exc_info=True)
+        model_status = f"Error loading model: {e}"
         raise RuntimeError(f"Failed to load model: {e}")
 
 # --- Generation Function ---
+def generate_response(prompt, history, max_length, temperature, top_p, top_k):
+    """Generates a response using the Phi-4 ONNX model, yielding partial results."""
+    global model_status
     if not model or not tokenizer:
+        model_status = "Error: Model not initialized!"
+        yield "Error: Model not initialized. Please check logs."
+        return
     if not prompt:
+        yield "Please enter a prompt."
+        return
 
     # --- Prepare the prompt using the Phi-4 instruct format ---
     full_prompt = ""
     for user_msg, assistant_msg in history:
         full_prompt += f"<|user|>\n{user_msg}<|end|>\n"
+        if assistant_msg:
             full_prompt += f"<|assistant|>\n{assistant_msg}<|end|>\n"
     full_prompt += f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
 
+    logging.info(f"Generating response (MaxL: {max_length}, Temp: {temperature}, TopP: {top_p}, TopK: {top_k})")
+    # logging.debug(f"Full Formatted Prompt:\n{full_prompt}")
 
     try:
         input_tokens = tokenizer.encode(full_prompt)
@@ -111,9 +116,9 @@ def generate_response(prompt, history, max_length=1024, temperature=0.7, top_p=0
             "temperature": temperature,
             "top_p": top_p,
             "top_k": top_k,
+            "do_sample": True,
+            "eos_token_id": tokenizer.eos_token_id,
+            "pad_token_id": tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id,
         }
 
         params = og.GeneratorParams(model)
@@ -123,83 +128,155 @@ def generate_response(prompt, history, max_length=1024, temperature=0.7, top_p=0
         start_time = time.time()
         generator = og.Generator(model, params)
         response_text = ""
+        model_status = "Generating..." # Update status indicator
         logging.info("Streaming response...")
 
+        first_token_time = None
         while not generator.is_done():
             generator.compute_logits()
             generator.generate_next_token()
+            if first_token_time is None:
+                first_token_time = time.time() # Record time to first token
             next_token = generator.get_next_tokens()[0]
+
             if next_token == search_options["eos_token_id"]:
+                logging.info("EOS token encountered.")
                 break
+
             decoded_chunk = tokenizer.decode([next_token])
+
+            # Handle potential decoding issues or special tokens if necessary
+            # (e.g., some models might output "<|end|>" which you might want to strip)
+            if decoded_chunk == "<|end|>": # Example: Stop if assistant outputs end token explicitly
+                logging.info("Assistant explicitly generated <|end|> token.")
+                break
+
             response_text += decoded_chunk
             yield response_text # Yield intermediate results for streaming effect
 
         end_time = time.time()
+        ttft = (first_token_time - start_time) * 1000 if first_token_time else -1
+        total_time = end_time - start_time
+        token_count = len(tokenizer.decode(generator.get_output_sequences()[0])) # Approx token count
+        tps = (token_count / total_time) if total_time > 0 else 0
 
+        logging.info(f"Generation complete. Tokens: ~{token_count}, Total Time: {total_time:.2f}s, TTFT: {ttft:.2f}ms, TPS: {tps:.2f}")
+        model_status = f"Model Ready ({EXECUTION_PROVIDER.upper()} / {model_variant_name})" # Reset status
+
+        # Final yield with the complete text
         yield response_text.strip()
 
     except Exception as e:
         logging.error(f"Error during generation: {e}", exc_info=True)
+        model_status = f"Error during generation: {e}"
+        yield f"Sorry, an error occurred during generation: {e}"
 
+# --- Clear Chat Function ---
+def clear_chat():
+    return None, None # Clears Textbox and Chatbot
 
 # --- Initialize Model on App Start ---
+# Wrap in try-except to allow Gradio UI to potentially load even if model fails
 try:
     initialize_model()
 except Exception as e:
     print(f"FATAL: Model initialization failed: {e}")
+    model_status = f"FATAL ERROR during init: {e}"
+    # The UI will still load, but generation will fail. The status will show the error.
 
 # --- Gradio Interface ---
 logging.info("Creating Gradio Interface...")
 
+# Select a theme
+theme = gr.themes.Soft(
+    primary_hue="blue",
+    secondary_hue="sky",
+    neutral_hue="slate",
+    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
+).set(
+    # Customize specific component styles if needed
+    # button_primary_background_fill="*primary_500",
+    # button_primary_background_fill_hover="*primary_400",
+)
+
+with gr.Blocks(theme=theme, title="Phi-4 Mini ONNX Chat") as demo:
+    # Header Section
+    with gr.Row(equal_height=False):
+        with gr.Column(scale=3):
+            gr.Markdown(f"""
+            # Phi-4 Mini Instruct ONNX Chat 🤖
+            Interact with the quantized `{model_variant_name}` version of [`{MODEL_REPO}`]({HF_MODEL_URL})
+            running efficiently via [`onnxruntime-genai`]({ORT_GENAI_URL}).
+            """)
+        with gr.Column(scale=1, min_width=150):
+            gr.Image(HF_LOGO_URL, elem_id="hf-logo", show_label=False, show_download_button=False, container=False, height=50)
+            model_status_text = gr.Textbox(value=model_status, label="Model Status", interactive=False, max_lines=2)
+
+
+    # Main Layout (Chat on Left, Settings on Right)
+    with gr.Row():
+        # Chat Column
+        with gr.Column(scale=3):
+            chatbot = gr.Chatbot(
+                label="Conversation",
+                height=600,
+                layout="bubble",
+                bubble_full_width=False,
+                avatar_images=(None, "https://microsoft.github.io/phi/assets/img/logo-final.png") # (user, bot) - Optional: Add user avatar path/URL if desired
+            )
+            with gr.Row():
+                prompt_input = gr.Textbox(
+                    label="Your Message",
+                    placeholder="<|user|>\nType your message here...\n<|end|>",
+                    lines=4,
+                    scale=9 # Make textbox wider
+                )
+                submit_button = gr.Button("Send", variant="primary", scale=1, min_width=120) # Primary send button
+                clear_button = gr.Button("🗑️ Clear", variant="secondary", scale=1, min_width=120) # Secondary clear button
+
+
+        # Settings Column
+        with gr.Column(scale=1, min_width=250):
+            gr.Markdown("### ⚙️ Generation Settings")
+            with gr.Group(): # Group settings visually
+                max_length = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max Length", info="Max tokens in response.")
+                temperature = gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature", info="0.0 = deterministic\n>1.0 = more random")
+                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-P", info="Nucleus sampling probability.")
+                top_k = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K", info="Limit to K most likely tokens (0=disable).")
+
+            gr.Markdown("---") # Separator
+            gr.Markdown("ℹ️ **Note:** Uses Phi-4 instruction format: \n`<|user|>\nPROMPT<|end|>\n<|assistant|>`")
+
+
+    # Event Listeners (Connecting UI components to functions)
+
+    # Define reusable inputs list for generation
+    gen_inputs = [prompt_input, chatbot, max_length, temperature, top_p, top_k]
+
+    # Submit action (using streaming yields from generate_response)
+    submit_button.click(
+        fn=generate_response,
+        inputs=gen_inputs,
+        outputs=[chatbot], # Output directly streams to chatbot
+        queue=True # Enable queuing
+    )
+    # Allow submitting via Enter key in the textbox as well
+    prompt_input.submit(
+        fn=generate_response,
+        inputs=gen_inputs,
+        outputs=[chatbot],
+        queue=True
     )
 
+    # Clear button action
+    clear_button.click(
+        fn=clear_chat,
+        inputs=None,
+        outputs=[prompt_input, chatbot], # Clear both input and chat history
+        queue=False # No need to queue clearing
+    )
 
+# Launch the Gradio app
 logging.info("Launching Gradio App...")
+demo.queue() # Enable queuing for handling concurrent users/requests
+demo.launch(show_error=True, max_threads=40) # show_error=True helps debug in Spaces
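
Because `generate_response` is a generator that yields the accumulated text, it can also be smoke-tested without the Gradio UI. A minimal sketch, assuming `initialize_model()` has already run and `generate_response()` is in scope (e.g., run inside the same file with the `demo.launch(...)` call temporarily commented out); the sample prompt and parameter values are illustrative only:

```python
# Illustrative smoke test; not part of this commit.
history = []  # list of (user, assistant) pairs, as consumed by generate_response
prompt = "Give me a one-sentence summary of ONNX Runtime."

final_text = ""
for partial in generate_response(prompt, history,
                                 max_length=256, temperature=0.7,
                                 top_p=0.9, top_k=50):
    final_text = partial  # each yield is the full response generated so far

print(final_text)
```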