Update app.py
Browse files
app.py
CHANGED
@@ -1,46 +1,61 @@
|
|
1 |
import gradio as gr
|
|
|
2 |
from transformers import AutoTokenizer
|
3 |
import onnxruntime as ort
|
4 |
-
import numpy as np
|
5 |
|
6 |
-
#
|
7 |
-
|
|
|
|
|
8 |
|
9 |
-
#
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
12 |
|
13 |
-
# Inference function
|
14 |
def generate_response(prompt):
|
|
|
15 |
full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
|
|
|
|
|
|
|
16 |
inputs = tokenizer(full_prompt, return_tensors="np")
|
17 |
-
|
18 |
-
# ONNX
|
19 |
ort_inputs = {
|
20 |
"input_ids": inputs["input_ids"].astype(np.int64),
|
21 |
"attention_mask": inputs["attention_mask"].astype(np.int64)
|
22 |
}
|
23 |
-
|
24 |
-
# Run model
|
25 |
outputs = session.run(None, ort_inputs)
|
|
|
|
|
|
|
26 |
generated_ids = outputs[0]
|
27 |
-
|
28 |
-
# Decode
|
29 |
response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
30 |
-
|
31 |
-
#
|
32 |
if "<|assistant|>" in response:
|
33 |
response = response.split("<|assistant|>")[-1].strip()
|
|
|
34 |
return response
|
35 |
|
36 |
-
# Gradio interface
|
37 |
-
|
38 |
fn=generate_response,
|
39 |
inputs=gr.Textbox(label="Your Prompt", placeholder="Type your question here...", lines=4),
|
40 |
outputs=gr.Textbox(label="AI Response"),
|
41 |
title="Phi-4-Mini ONNX Chatbot",
|
42 |
-
description=
|
|
|
|
|
|
|
43 |
)
|
44 |
|
45 |
-
# Launch the app
|
46 |
-
|
|
|
1 |
import gradio as gr
|
2 |
+
import numpy as np
|
3 |
from transformers import AutoTokenizer
|
4 |
import onnxruntime as ort
|
|
|
5 |
|
6 |
+
# Load the tokenizer from the Hugging Face hub.
|
7 |
+
# This loads files like `tokenizer.json`, `vocab.json`, etc. from the repository root.
|
8 |
+
model_repo = "microsoft/Phi-4-mini-instruct-onnx"
|
9 |
+
tokenizer = AutoTokenizer.from_pretrained(model_repo)
|
10 |
|
11 |
+
# Specify the relative path to the ONNX model files stored in the repository subfolder.
|
12 |
+
# You need to have downloaded these LFS files either locally or ensure your environment can access them.
|
13 |
+
onnx_model_path = "cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4/model.onnx"
|
14 |
+
|
15 |
+
# Create an ONNX Runtime session.
|
16 |
+
session = ort.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
|
17 |
|
|
|
18 |
def generate_response(prompt):
|
19 |
+
# Prepare the prompt with a simple instruction format.
|
20 |
full_prompt = f"<|user|>\n{prompt}\n<|assistant|>\n"
|
21 |
+
|
22 |
+
# Tokenize the input.
|
23 |
+
# The tokenizer returns NumPy arrays (using return_tensors="np").
|
24 |
inputs = tokenizer(full_prompt, return_tensors="np")
|
25 |
+
|
26 |
+
# ONNX runtime requires inputs of type int64.
|
27 |
ort_inputs = {
|
28 |
"input_ids": inputs["input_ids"].astype(np.int64),
|
29 |
"attention_mask": inputs["attention_mask"].astype(np.int64)
|
30 |
}
|
31 |
+
|
32 |
+
# Run the model inference.
|
33 |
outputs = session.run(None, ort_inputs)
|
34 |
+
|
35 |
+
# Assuming the model returns logits or generated IDs in the first element.
|
36 |
+
# Here we assume the model output contains generated token IDs.
|
37 |
generated_ids = outputs[0]
|
38 |
+
|
39 |
+
# Decode the generated token IDs into text.
|
40 |
response = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
|
41 |
+
|
42 |
+
# Optionally, remove earlier prompt parts if your model returns the input tokens as well.
|
43 |
if "<|assistant|>" in response:
|
44 |
response = response.split("<|assistant|>")[-1].strip()
|
45 |
+
|
46 |
return response
|
47 |
|
48 |
+
# Create a Gradio interface to interact with the model.
|
49 |
+
interface = gr.Interface(
|
50 |
fn=generate_response,
|
51 |
inputs=gr.Textbox(label="Your Prompt", placeholder="Type your question here...", lines=4),
|
52 |
outputs=gr.Textbox(label="AI Response"),
|
53 |
title="Phi-4-Mini ONNX Chatbot",
|
54 |
+
description=(
|
55 |
+
"Chat interface powered by microsoft/Phi-4-mini-instruct-onnx. "
|
56 |
+
"The ONNX model is loaded from the int4-optimized subfolder (cpu_and_mobile/cpu-int4-rtn-block-32-acc-level-4)."
|
57 |
+
)
|
58 |
)
|
59 |
|
60 |
+
# Launch the Gradio app.
|
61 |
+
interface.launch()
|