Update app.py
app.py
CHANGED
@@ -1,16 +1,21 @@
 import gradio as gr
 import torch
 import os
-
-from
+# Removed the __version__ import from optimum.onnxruntime
+from transformers import AutoTokenizer, __version__ as transformers_version
+from optimum.onnxruntime import ORTModelForCausalLM
+# import optimum  # try to check Optimum's own version (optional)

 # --- Configuration ---
-MODEL_ID = "onnx-community/gemma-3-1b-it-ONNX-GQA"
-ONNX_FILE_NAME = None
+MODEL_ID = "onnx-community/gemma-3-1b-it-ONNX-GQA"
+ONNX_FILE_NAME = None

 print(f"Using Transformers version: {transformers_version}")
-
-print(f"Using
+# try:
+#     print(f"Using Optimum version: {optimum.__version__}")  # try checking the version another way
+# except AttributeError:
+#     print("Could not determine Optimum version automatically.")
+print(f"Using Gradio version: {gr.__version__}")

 # --- Device Selection ---
 try:
@@ -43,25 +48,23 @@ try:
     model = ORTModelForCausalLM.from_pretrained(
         MODEL_ID,
         provider=provider,
-        use_cache=True,
-        # use_io_binding=False  # try False if problems occur when using the GPU
+        use_cache=True,
     )
     print(f"ONNX Model '{MODEL_ID}' loaded successfully with provider '{provider}'.")
     model_loaded_successfully = True

 except ValueError as ve:
-    #
+    # Handle the unsupported-model-type error
     print(f"!!!!!!!!!!!!!! CRITICAL MODEL LOADING ERROR (ValueError) !!!!!!!!!!!!!!")
     print(f"Model: {MODEL_ID}")
     print(f"Error message: {ve}")
     print("This likely means the installed 'transformers' library version does NOT support the 'gemma3_text' architecture.")
     print("Ensure 'requirements.txt' specifies a recent version (e.g., transformers>=4.41.0) and the Space has been rebuilt/restarted.")
     print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
-    # Clearly notify the user when model loading fails
     model_loaded_successfully = False

 except Exception as e:
-    # Other
+    # Handle any other exception
     print(f"!!!!!!!!!!!!!! UNEXPECTED MODEL LOADING ERROR !!!!!!!!!!!!!!")
     print(f"Model: {MODEL_ID}")
     print(f"Error type: {type(e).__name__}")
@@ -73,33 +76,23 @@ except Exception as e:
 # --- Chat Function ---
 def chat_function(message: str, history: list):
     if not model_loaded_successfully or model is None or tokenizer is None:
-        # Return an error message when the model failed to load
         return "Error: The AI model is not loaded. Please check the application logs."

     try:
-        # Chat
+        # Convert the chat history
         chat_messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
         for user_msg, model_msg in history:
-
-            if
-
-
-
-        if message:  # add the current user message
-            chat_messages.append({"role": "user", "content": message})
-
-        # Build the prompt (try apply_chat_template, fall back to manual construction on failure)
+            if user_msg: chat_messages.append({"role": "user", "content": user_msg})
+            if model_msg: chat_messages.append({"role": "model", "content": model_msg})
+        if message: chat_messages.append({"role": "user", "content": message})
+
+        # Build the prompt
         prompt = ""
         try:
-            prompt = tokenizer.apply_chat_template(
-                chat_messages,
-                tokenize=False,
-                add_generation_prompt=True
-            )
+            prompt = tokenizer.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=True)
         except Exception as template_error:
             print(f"Warning: Failed to apply chat template ({template_error}). Using manual prompt construction.")
             prompt_parts = ["<start_of_turn>system\nYou are a helpful AI assistant.<end_of_turn>"]
-            # Note: the model message in history may be None
             for user_msg, model_msg in history:
                 if user_msg: prompt_parts.append(f"<start_of_turn>user\n{user_msg}<end_of_turn>")
                 if model_msg: prompt_parts.append(f"<start_of_turn>model\n{model_msg}<end_of_turn>")
@@ -107,14 +100,12 @@ def chat_function(message: str, history: list):
             prompt_parts.append("<start_of_turn>model")
             prompt = "\n".join(prompt_parts)

-        #
-
-        # Tokenize the input and move it to the device
+        # Tokenize the input
         inputs = tokenizer(prompt, return_tensors="pt").to(device)

         # Generate the response
         print("Generating response...")
-        with torch.no_grad():
+        with torch.no_grad():
             outputs = model.generate(
                 **inputs,
                 max_new_tokens=512,
@@ -122,50 +113,34 @@
                 temperature=0.7,
                 top_k=50,
                 top_p=0.9,
-                pad_token_id=tokenizer.eos_token_id
+                pad_token_id=tokenizer.eos_token_id
             )
         print("Generation complete.")

-        # Decode
+        # Decode
         input_token_len = inputs['input_ids'].shape[1]
         generated_tokens = outputs[0][input_token_len:]
         response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
-
-        # Post-process
         response = response.replace("<end_of_turn>", "").strip()
-
-        # print(f"--- RESPONSE --- \n{response}\n--------------")
-
-        # Handle an empty response
         if not response:
             print("Warning: Generated empty response.")
             response = "Sorry, I couldn't generate a response for that."
-
         return response

     except Exception as e:
         print(f"!!!!!!!!!!!!!! Error during generation !!!!!!!!!!!!!!")
         print(f"Error type: {type(e).__name__}")
         print(f"Error message: {e}")
-        print("Input message:", message)
-        # traceback.print_exc()  # print the full traceback if needed
         print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
         return f"Sorry, an error occurred during response generation. Please check logs."

-
-# --- Gradio Interface (modified) ---
+# --- Gradio Interface ---
 print("Creating Gradio Interface...")
 iface = gr.ChatInterface(
-    fn=chat_function,
+    fn=chat_function,
     title="AI Assistant (Gemma 3 1B ONNX-GQA)",
     description=f"Chat with {MODEL_ID}. Model loaded: {model_loaded_successfully}",
-    # Add type='messages' to the chatbot widget
     chatbot=gr.Chatbot(height=600, type="messages", bubble_full_width=False),
-    # Remove unsupported button arguments
-    # retry_btn=None,  # removed
-    # undo_btn=None,  # removed
-    # clear_btn=None,  # removed
-    # Use the default button instead of submit_btn
     theme=gr.themes.Soft(),
     examples=[["Hello!"], ["Write a poem about the internet."]]
 )
@@ -173,5 +148,4 @@ iface = gr.ChatInterface(
 # --- Launch App ---
 if __name__ == "__main__":
     print("Launching Gradio App...")
-    # Launch the interface even if model loading fails, but show an error message
     iface.launch()
|