Athspi committed on
Commit 92acddd · verified · 1 Parent(s): 346197d

Update app.py

Files changed (1)
  1. app.py +216 -50
app.py CHANGED
@@ -1,61 +1,227 @@
  import gradio as gr
- import numpy as np
  from transformers import AutoTokenizer
  import onnxruntime as ort

- # Load the tokenizer from the Hugging Face hub.
- # This loads files like `tokenizer.json`, `vocab.json`, etc. from the repository root.
- model_repo = "microsoft/Phi-4-mini-instruct-onnx"
- tokenizer = AutoTokenizer.from_pretrained(model_repo)

- # Specify the relative path to the ONNX model files stored in the repository subfolder.
- # You need to have downloaded these LFS files either locally or ensure your environment can access them.
- onnx_model_path = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/model.onnx"

- # Create an ONNX Runtime session.
- session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])

  def generate_response(prompt):
-     # Prepare the prompt with a simple instruction format.
      full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
-
-     # Tokenize the input.
-     # The tokenizer returns NumPy arrays (using return_tensors="np").
-     inputs = tokenizer(full_prompt, return_tensors="np")
-
-     # ONNX runtime requires inputs of type int64.
-     ort_inputs = {
-         "input_ids": inputs["input_ids"].astype(np.int64),
-         "attention_mask": inputs["attention_mask"].astype(np.int64)
-     }
-
-     # Run the model inference.
-     outputs = session.run(None, ort_inputs)
-
-     # Assuming the model returns logits or generated IDs in the first element.
-     # Here we assume the model output contains generated token IDs.
-     generated_ids = outputs[0]
-
-     # Decode the generated token IDs into text.
-     response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
-
-     # Optionally, remove earlier prompt parts if your model returns the input tokens as well.
-     if "<|assistant|>" in response:
-         response = response.split("<|assistant|>")[-1].strip()
-
-     return response
-
- # Create a Gradio interface to interact with the model.
- interface = gr.Interface(
-     fn=generate_response,
-     inputs=gr.Textbox(label="Your Prompt", placeholder="Type your question here...", lines=4),
-     outputs=gr.Textbox(label="AI Response"),
-     title="Phi-4-Mini ONNX Chatbot",
-     description=(
-         "Chat interface powered by microsoft/Phi-4-mini-instruct-onnx. "
-         "The ONNX model is loaded from the int4-optimized subfolder (cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4)."
      )
- )

- # Launch the Gradio app.
- interface.launch()
 
+ # app.py
+
  import gradio as gr
  from transformers import AutoTokenizer
  import onnxruntime as ort
+ import numpy as np
+ import os  # Import os module to check if model directory exists
+ import time  # To measure performance (optional)
+
+ print("Loading libraries...")

+ # --- Configuration ---
+ # Define the local directory where the downloaded model files are stored.
+ # This path MUST match where you downloaded the model files relative to this script.
+ model_dir = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4"

+ # --- Model Loading ---
+ tokenizer = None
+ session = None
+ model_load_error = None

+ # Check if the model directory exists before attempting to load
+ if not os.path.isdir(model_dir):
+     model_load_error = (
+         f"Error: Model directory not found at '{os.path.abspath(model_dir)}'\n"
+         "Please ensure you have created the directory structure\n"
+         f"'./{model_dir}' relative to this script ({os.path.basename(__file__)})\n"
+         "and downloaded ALL the required model files into it from:\n"
+         "https://huggingface.co/microsoft/Phi-4-mini-instruct-onnx/tree/main/cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4"
+     )
+     print(model_load_error)
+ else:
+     print(f"Found model directory: {os.path.abspath(model_dir)}")
+     print("Loading tokenizer...")
+     try:
+         # Load tokenizer associated with the Phi-4 model variant
+         tokenizer = AutoTokenizer.from_pretrained(model_dir)
+         print("Tokenizer loaded successfully.")
+     except Exception as e:
+         model_load_error = f"Error loading tokenizer from {model_dir}: {e}"
+         print(model_load_error)
+
+     # Only attempt to load session if tokenizer loaded successfully
+     if tokenizer:
+         print("Loading ONNX model session...")
+         model_path = os.path.join(model_dir, "model.onnx")
+         model_data_path = os.path.join(model_dir, "model.onnx.data")

+         if not os.path.exists(model_path):
+             model_load_error = f"Error: 'model.onnx' not found in {model_dir}"
+             print(model_load_error)
+         elif not os.path.exists(model_data_path):
+             model_load_error = f"Error: 'model.onnx.data' not found in {model_dir}. This large file contains the model weights and is required."
+             print(model_load_error)
+         else:
+             try:
+                 # Load the ONNX model using ONNX Runtime for CPU execution
+                 start_time = time.time()
+                 # You can configure session options for performance if needed
+                 # sess_options = ort.SessionOptions()
+                 # sess_options.intra_op_num_threads = 4  # Example: Limit threads
+                 session = ort.InferenceSession(
+                     model_path,
+                     providers=["CPUExecutionProvider"]
+                     # sess_options=sess_options  # Uncomment to use options
+                 )
+                 end_time = time.time()
+                 print(f"ONNX model session loaded successfully using CPU provider in {end_time - start_time:.2f} seconds.")
+             except Exception as e:
+                 model_load_error = f"Error loading ONNX session from {model_path}: {e}\n"
+                 model_load_error += "Ensure 'onnxruntime' library is installed correctly and that both 'model.onnx' and 'model.onnx.data' are valid files."
+                 print(model_load_error)
+
+ # --- Inference Function ---
  def generate_response(prompt):
+     """
+     Generates a response from the loaded ONNX model based on the user prompt.
+     """
+     global tokenizer, session, model_load_error  # Allow access to global vars
+
+     # Check if model loading failed earlier
+     if model_load_error:
+         return model_load_error
+     if not tokenizer or not session:
+         return "Error: Model or Tokenizer is not loaded correctly. Check console output."
+
+     print(f"\nReceived prompt: {prompt}")
+     start_time = time.time()
+
+     # Format the prompt with specific markers for instruction following
      full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
+     print("Tokenizing input...")
+
+     try:
+         # Tokenize the formatted prompt
+         inputs = tokenizer(full_prompt, return_tensors="np")
+
+         # Prepare inputs for the ONNX model
+         ort_inputs = {
+             "input_ids": inputs["input_ids"].astype(np.int64),
+             "attention_mask": inputs["attention_mask"].astype(np.int64)
+         }
+         print("Running model inference...")
+         inference_start_time = time.time()
+
+         # Run the ONNX model inference
+         outputs = session.run(None, ort_inputs)
+         generated_ids = outputs[0]  # Assuming the first output contains the generated IDs
+
+         inference_end_time = time.time()
+         print(f"Inference complete in {inference_end_time - inference_start_time:.2f} seconds.")
+
+         # Decode the generated token IDs back into text
+         print("Decoding response...")
+         decoding_start_time = time.time()
+         # Ensure generated_ids is 1D if necessary, might be shape (1, sequence_length)
+         output_ids = generated_ids[0] if generated_ids.ndim == 2 else generated_ids
+         response = tokenizer.decode(output_ids, skip_special_tokens=True)
+         decoding_end_time = time.time()
+         print(f"Decoding complete in {decoding_end_time - decoding_start_time:.2f} seconds.")
+
+         # --- Response Cleaning ---
+         # 1. Find the start of the assistant's response
+         assistant_marker = "<|assistant|>"
+         assistant_pos = response.find(assistant_marker)
+
+         if assistant_pos != -1:
+             # If marker found, take text after it
+             cleaned_response = response[assistant_pos + len(assistant_marker):].strip()
+         else:
+             # Fallback: If marker isn't perfectly decoded, try removing the original input prompt
+             # This assumes the model might prepend the input sometimes.
+             # Remove the prompt part *without* the final <|assistant|> tag
+             prompt_part_to_remove = full_prompt.rsplit(assistant_marker, 1)[0]
+             if response.startswith(prompt_part_to_remove):
+                 cleaned_response = response[len(prompt_part_to_remove):].strip()
+             else:
+                 # If neither works well, return the raw response (might contain parts of the prompt)
+                 cleaned_response = response.strip()
+                 print("Warning: Could not reliably clean the prompt context from the response.")
+
+
+         total_time = time.time() - start_time
+         print(f"Generated response: {cleaned_response}")
+         print(f"Total processing time for this prompt: {total_time:.2f} seconds.")
+         return cleaned_response
+
+     except Exception as e:
+         print(f"Error during model inference or decoding: {e}")
+         import traceback
+         traceback.print_exc()  # Print detailed traceback for debugging
+         return f"Error during generation: {e}"
+
+ # --- Gradio Interface Setup ---
+ print("Setting up Gradio interface...")
+
+ # Define CSS for better layout (optional)
+ css = """
+ #output_textbox textarea {
+     min-height: 300px; /* Make output box taller */
+ }
+ #input_textbox textarea {
+     min-height: 100px; /* Adjust input box height */
+ }
+ """
+
+ demo = gr.Blocks(css=css, theme=gr.themes.Default())  # Use Blocks for more layout control
+
+ with demo:
+     gr.Markdown(
+         """
+         # Phi-4-Mini ONNX Chatbot (Local CPU)
+         Interact with the `microsoft/Phi-4-mini-instruct-onnx` model variant
+         (`cpu-int4-rtn-block-32-acc-level-4`) running locally using ONNX Runtime on your CPU.
+         """
      )
+     with gr.Row():
+         with gr.Column(scale=2):  # Input column
+             input_textbox = gr.Textbox(
+                 label="Your Prompt",
+                 placeholder="Type your question or instruction here...",
+                 lines=4,  # Initial lines, resizable
+                 elem_id="input_textbox"  # Assign ID for CSS
+             )
+             submit_button = gr.Button("Generate Response", variant="primary")
+         with gr.Column(scale=3):  # Output column
+             output_textbox = gr.Textbox(
+                 label="AI Response",
+                 lines=10,  # Initial lines, resizable
+                 interactive=False,  # User cannot type in the output box
+                 elem_id="output_textbox"  # Assign ID for CSS
+             )
+
+     # Display model loading status/errors
+     if model_load_error:
+         gr.Markdown(f"**<font color='red'>Model Loading Error:</font>**\n```\n{model_load_error}\n```")
+     elif session is None or tokenizer is None:
+         gr.Markdown("**<font color='orange'>Warning:</font>** Model or tokenizer did not load correctly. Check console logs.")
+     else:
+         gr.Markdown("**<font color='green'>Model and Tokenizer Loaded Successfully.</font>**")
+
+     # Connect button click to the function
+     submit_button.click(
+         fn=generate_response,
+         inputs=input_textbox,
+         outputs=output_textbox
+     )
+
+     # Allow submitting by pressing Enter in the input textbox
+     input_textbox.submit(
+         fn=generate_response,
+         inputs=input_textbox,
+         outputs=output_textbox
+     )
+
+ # --- Launch the Application ---
+ print("-" * 50)
+ print("Launching Gradio app...")
+ print("You can access it in your browser at the URL provided below (usually http://127.0.0.1:7860).")
+ print("Press CTRL+C in this terminal to stop the application.")
+ print("-" * 50)
+
+ # share=True creates a temporary public link (use with caution).
+ # Set debug=True for more detailed Gradio errors if needed.
+ demo.launch(share=False, debug=False)

+ print("Gradio app closed.")