Athspi committed on
Commit eed5424 · verified · 1 Parent(s): 1dd4d6a

Update app.py

Files changed (1)
  1. app.py +205 -20
app.py CHANGED
@@ -1,20 +1,205 @@
- # Core Gradio UI library
- gradio>=4.0.0,<5.0.0
-
- # ONNX Runtime GenAI for inference (CPU version)
- # Use --pre as it's often in pre-release
- # NOTE: If targeting a GPU Space, change this to onnxruntime-genai-cuda
- # and update EXECUTION_PROVIDER in app.py to "cuda"
- onnxruntime-genai --pre
-
- # Hugging Face Hub for downloading models/files
- huggingface_hub>=0.20.0
-
- # ONNX Runtime itself (will be installed as a dependency of onnxruntime-genai,
- # but specifying can sometimes help resolve version conflicts if needed)
- # onnxruntime>=1.17.0 # Generally not needed to list explicitly
-
- # Git LFS is needed by huggingface_hub to download large model files.
- # It needs to be installed on the Space environment, which is usually handled
- # by the Hugging Face Spaces infrastructure if not using Docker.
- # If you encounter download issues, ensure git-lfs is available.
+ import gradio as gr
+ import onnxruntime_genai as og
+ import time
+ import os
+ from huggingface_hub import snapshot_download
+ import argparse
+ import logging
+
+ # --- Logging Setup ---
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+ # --- Configuration ---
+ MODEL_REPO = "microsoft/Phi-4-mini-instruct-onnx"
+
+ # --- Defaulting to CPU INT4 for Hugging Face Spaces ---
+ # Free Spaces generally provide CPU resources.
+ # If deploying on a paid GPU Space, you would change these
+ # and the requirements.txt accordingly.
+ EXECUTION_PROVIDER = "cpu"
+ MODEL_VARIANT_GLOB = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/*"
+ # Ensure requirements.txt lists: onnxruntime-genai
+ # --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---
+
+ # --- Alternative GPU Configuration (Requires GPU Space & requirements change) ---
+ # EXECUTION_PROVIDER = "cuda"
+ # MODEL_VARIANT_GLOB = "gpu/gpu-int4-rtn-block-32/*"
+ # Ensure requirements.txt lists: onnxruntime-genai-cuda
+ # --- --- --- --- --- --- --- --- --- --- --- --- --- --- ---
+
+ LOCAL_MODEL_DIR = "./phi4-mini-onnx-model"  # Directory within the Space
+ HF_LOGO_URL = "https://huggingface.co/front/assets/huggingface_logo-noborder.svg"  # Official HF Logo
+
+ # Global variables for model and tokenizer
+ model = None
+ tokenizer = None
+ model_variant_name = os.path.basename(os.path.dirname(MODEL_VARIANT_GLOB))  # For display
+
+ # --- Model Download and Load ---
+ def initialize_model():
+     """Downloads and loads the ONNX model and tokenizer."""
+     global model, tokenizer
+     logging.info("--- Initializing ONNX Runtime GenAI ---")
+
+     # --- Download ---
+     model_variant_dir = os.path.join(LOCAL_MODEL_DIR, os.path.dirname(MODEL_VARIANT_GLOB))
+     if os.path.exists(model_variant_dir) and os.listdir(model_variant_dir):
+         logging.info(f"Model variant found in {model_variant_dir}. Skipping download.")
+         model_path = model_variant_dir
+     else:
+         logging.info(f"Downloading model variant '{MODEL_VARIANT_GLOB}' from {MODEL_REPO}...")
+         try:
+             # Use cache_dir for potentially faster re-runs if Space storage allows caching
+             # cache_dir = os.path.join(LOCAL_MODEL_DIR, ".cache")  # Optional: define cache dir
+             snapshot_download(
+                 MODEL_REPO,
+                 allow_patterns=[MODEL_VARIANT_GLOB],
+                 local_dir=LOCAL_MODEL_DIR,
+                 local_dir_use_symlinks=False  # Safest for cross-platform/Space compatibility
+                 # cache_dir=cache_dir  # Optional
+             )
+             model_path = model_variant_dir
+             logging.info(f"Model downloaded to: {model_path}")
+         except Exception as e:
+             logging.error(f"Error downloading model: {e}", exc_info=True)
+             logging.error("Please ensure the Space has internet access and necessary permissions.")
+             logging.error("Check Hugging Face Hub status if issues persist.")
+             # Optionally raise to stop the app, or try to proceed if partial files might exist
+             raise RuntimeError(f"Failed to download model: {e}")
+
+     # --- Load ---
+     logging.info(f"Loading model from: {model_path}")
+     logging.info(f"Using Execution Provider: {EXECUTION_PROVIDER.upper()}")
+     try:
+         og_device_type = getattr(og.DeviceType, EXECUTION_PROVIDER.upper(), og.DeviceType.CPU)
+         model = og.Model(model_path, og_device_type)
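+         # NOTE: the two-argument og.Model(path, DeviceType) call assumes an older
+         # pre-release onnxruntime-genai API; newer releases drop og.DeviceType and
+         # take only the model path, inferring the device from the installed package.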
+         tokenizer = og.Tokenizer(model)
+         logging.info("Model and Tokenizer loaded successfully.")
+     except Exception as e:
+         logging.error(f"Error loading model or tokenizer: {e}", exc_info=True)
+         logging.error(f"Ensure the correct onnxruntime-genai package is installed (check requirements.txt) for {EXECUTION_PROVIDER}.")
+         logging.error(f"Verify model file integrity in '{model_path}'.")
+         raise RuntimeError(f"Failed to load model: {e}")
+
+ # --- Generation Function ---
+ def generate_response(prompt, history, max_length=1024, temperature=0.7, top_p=0.9, top_k=50):
+     """Generates a response using the Phi-4 ONNX model."""
+     history = history or []
+     if not model or not tokenizer:
+         yield history + [[prompt, "Error: Model not initialized. Please check logs."]]
+         return
+     if not prompt:
+         yield history + [[prompt, "Please enter a prompt."]]
+         return
+
+     # --- Prepare the prompt using the Phi-4 instruct format ---
+     # "<|user|>\n{user_message}<|end|>\n<|assistant|>\n{assistant_message}<|end|>"
+     full_prompt = ""
+     for user_msg, assistant_msg in history:
+         full_prompt += f"<|user|>\n{user_msg}<|end|>\n"
+         if assistant_msg:  # Add assistant message only if it exists
+             full_prompt += f"<|assistant|>\n{assistant_msg}<|end|>\n"
+
+     # Add the current user prompt and the trigger for the assistant's response
+     full_prompt += f"<|user|>\n{prompt}<|end|>\n<|assistant|>\n"
+
+     logging.info(f"Generating response for prompt (last part): ...{prompt[-50:]}")
+     # logging.debug(f"Full Formatted Prompt:\n{full_prompt}")  # Use debug level
+
+     try:
+         input_tokens = tokenizer.encode(full_prompt)
+
+         search_options = {
+             "max_length": max_length,
+             "temperature": temperature,
+             "top_p": top_p,
+             "top_k": top_k,
+             "do_sample": True,  # Sampling is generally preferred for chat
+             "eos_token_id": tokenizer.eos_token_id,  # Important for stopping generation
+             "pad_token_id": tokenizer.pad_token_id if hasattr(tokenizer, 'pad_token_id') else tokenizer.eos_token_id,  # Use EOS if PAD not explicit
+         }
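+         # NOTE: in onnxruntime-genai, "max_length" is the total sequence length
+         # (prompt tokens plus generated tokens), not a budget of new tokens alone.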
+
+         params = og.GeneratorParams(model)
+         params.set_search_options(**search_options)
+         params.input_ids = input_tokens
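+         # NOTE: params.input_ids is the older pre-release API; newer onnxruntime-genai
+         # releases instead call generator.append_tokens(input_tokens) after the
+         # generator is created, matching the og.Model note above.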
+
+         start_time = time.time()
+         generator = og.Generator(model, params)
+         response_text = ""
+         logging.info("Streaming response...")
+
+         # Simple token streaming - yield partial results for Gradio
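+         # NOTE: decoding single tokens can split multi-byte characters; if the installed
+         # version provides tokenizer.create_stream(), a TokenizerStream gives cleaner
+         # incremental decoding.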
+         while not generator.is_done():
+             generator.compute_logits()
+             generator.generate_next_token()
+             next_token = generator.get_next_tokens()[0]
+             # Important: Check for EOS token ID to stop manually if needed
+             if next_token == search_options["eos_token_id"]:
+                 break
+             decoded_chunk = tokenizer.decode([next_token])
+             response_text += decoded_chunk
+             yield history + [[prompt, response_text]]  # Yield intermediate results for streaming effect
+
+         end_time = time.time()
+         logging.info(f"Generation complete. Time taken: {end_time - start_time:.2f} seconds")
+         logging.info(f"Full Response (last 100 chars): ...{response_text[-100:]}")
+
+         # Final yield with the complete text as (user, assistant) pairs for the Chatbot
+         yield history + [[prompt, response_text.strip()]]
+
+     except Exception as e:
+         logging.error(f"Error during generation: {e}", exc_info=True)
+         yield history + [[prompt, f"Sorry, an error occurred during generation: {e}"]]  # Yield error message
+
+
+ # --- Initialize Model on App Start ---
+ try:
+     initialize_model()
+ except Exception as e:
+     print(f"FATAL: Model initialization failed: {e}")
+     # Optionally create a dummy Gradio interface showing the error
+     # Or just let the script exit/fail in the Space environment
+
+ # --- Gradio Interface ---
+ logging.info("Creating Gradio Interface...")
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown(f"""
+ # Phi-4 Mini Instruct ONNX Demo
+ Powered by [`onnxruntime-genai`](https://github.com/microsoft/onnxruntime-genai) running the **{EXECUTION_PROVIDER.upper()}** INT4 ONNX Runtime version of [`{MODEL_REPO}`](https://huggingface.co/{MODEL_REPO}).
+ Model Variant: `{model_variant_name}`
+
+ <img src="{HF_LOGO_URL}" alt="Hugging Face Logo" style="display: inline-block; height: 1.5em; vertical-align: middle;"> This Space demonstrates running Phi-4 Mini efficiently with ONNX Runtime.
+     """)
+
+     chatbot = gr.Chatbot(label="Chat History", height=500, layout="bubble", bubble_full_width=False)
+     msg = gr.Textbox(
+         label="Your Prompt",
+         placeholder="<|user|>\nType your message here...<|end|>\n<|assistant|>",
+         lines=3,
+         info="Using the recommended Phi-4 instruct format."  # Add info text
+     )
+     clear = gr.Button("Clear Chat")
+
+     with gr.Accordion("Generation Parameters", open=False):
+         max_length = gr.Slider(minimum=64, maximum=4096, value=1024, step=64, label="Max Length", info="Maximum number of tokens to generate.")
+         temperature = gr.Slider(minimum=0.0, maximum=1.5, value=0.7, step=0.05, label="Temperature", info="Higher values (e.g., 0.8) make output more random, lower values (e.g., 0.2) make it more deterministic.")
+         top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.05, label="Top-P (Nucleus Sampling)", info="Filters vocabulary to the smallest set whose cumulative probability exceeds P. Set to 1.0 to disable.")
+         top_k = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Top-K", info="Filters vocabulary to the K most likely tokens. Set to 0 to disable.")
+
+     # generate_response is a generator function, so msg.submit streams its yields
+     # into the chatbot, updating it gradually as tokens arrive
+     msg.submit(
+         generate_response,
+         inputs=[msg, chatbot, max_length, temperature, top_p, top_k],
+         outputs=[chatbot]
+     )
+
+     # Connect the clear button
+     clear.click(lambda: (None, None), None, [msg, chatbot], queue=False)  # Clear input and chatbot
+
+     gr.Markdown("Enter your prompt using the suggested format and press Enter. Adjust generation parameters in the accordion above.")
+
+ logging.info("Launching Gradio App...")
+ # Setting share=False is default and recommended for Spaces
+ # queue() is important for handling multiple users
+ # debug=True can be useful for local testing but should generally be False in production/Spaces
+ demo.queue()
+ demo.launch(show_error=True)  # Show errors in the UI for easier debugging in Spaces