kimhyunwoo committed on
Commit
ddd5b6c
·
verified ·
1 Parent(s): e761d96

Update app.py

Files changed (1)
  1. app.py +27 -53
app.py CHANGED
@@ -1,16 +1,21 @@
 import gradio as gr
 import torch
 import os
-from transformers import AutoTokenizer, __version__ as transformers_version  # import added for version check
-from optimum.onnxruntime import ORTModelForCausalLM, __version__ as optimum_version  # import added for version check
+# __version__ import removed from optimum.onnxruntime
+from transformers import AutoTokenizer, __version__ as transformers_version
+from optimum.onnxruntime import ORTModelForCausalLM
+# import optimum  # optional: check optimum's own version
 
 # --- Configuration ---
-MODEL_ID = "onnx-community/gemma-3-1b-it-ONNX-GQA"  # the GQA model specified by the user
-ONNX_FILE_NAME = None  # try to auto-detect the file name
+MODEL_ID = "onnx-community/gemma-3-1b-it-ONNX-GQA"
+ONNX_FILE_NAME = None
 
 print(f"Using Transformers version: {transformers_version}")
-print(f"Using Optimum version: {optimum_version}")
-print(f"Using Gradio version: {gr.__version__}")  # log the Gradio version
+# try:
+#     print(f"Using Optimum version: {optimum.__version__}")  # check the version another way
+# except AttributeError:
+#     print("Could not determine Optimum version automatically.")
+print(f"Using Gradio version: {gr.__version__}")
 
 # --- Device Selection ---
 try:
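
The removed import is the point of this commit: optimum.onnxruntime apparently does not re-export a __version__ attribute, so importing it crashed at startup. If version logging is still wanted, a minimal sketch using only the standard library (Python 3.8+) avoids __version__ attributes entirely:

# Sketch: read installed package versions from package metadata instead of
# importing __version__ from a submodule (importlib.metadata, Python 3.8+).
from importlib.metadata import version, PackageNotFoundError

for pkg in ("transformers", "optimum", "gradio"):
    try:
        print(f"Using {pkg} version: {version(pkg)}")
    except PackageNotFoundError:
        print(f"Could not determine {pkg} version (is it installed?).")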
@@ -43,25 +48,23 @@ try:
     model = ORTModelForCausalLM.from_pretrained(
         MODEL_ID,
         provider=provider,
-        use_cache=True,  # use the KV cache
-        # use_io_binding=False  # try False if problems occur on GPU
+        use_cache=True,
     )
     print(f"ONNX Model '{MODEL_ID}' loaded successfully with provider '{provider}'.")
     model_loaded_successfully = True
 
 except ValueError as ve:
-    # a ValueError most likely means the model type is unsupported
+    # handle unsupported-model-type errors
     print(f"!!!!!!!!!!!!!! CRITICAL MODEL LOADING ERROR (ValueError) !!!!!!!!!!!!!!")
     print(f"Model: {MODEL_ID}")
     print(f"Error message: {ve}")
     print("This likely means the installed 'transformers' library version does NOT support the 'gemma3_text' architecture.")
     print("Ensure 'requirements.txt' specifies a recent version (e.g., transformers>=4.41.0) and the Space has been rebuilt/restarted.")
     print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
-    # clearly notify the user when model loading fails
     model_loaded_successfully = False
 
 except Exception as e:
-    # handle other kinds of exceptions (out of memory, network, etc.)
+    # handle other exceptions
     print(f"!!!!!!!!!!!!!! UNEXPECTED MODEL LOADING ERROR !!!!!!!!!!!!!!")
     print(f"Model: {MODEL_ID}")
     print(f"Error type: {type(e).__name__}")
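
The "# --- Device Selection ---" block that sets the provider and device variables falls outside these hunks. For orientation, here is a minimal sketch of the kind of selection it implies; the actual logic in app.py is not shown, so treat this as an assumption (requires torch and onnxruntime):

# Sketch (assumed, not from app.py): pick an ONNX Runtime execution provider
# and a matching torch device string, consistent with how `provider` and
# `device` are used in the later hunks.
import torch
import onnxruntime as ort

if torch.cuda.is_available() and "CUDAExecutionProvider" in ort.get_available_providers():
    provider, device = "CUDAExecutionProvider", "cuda"
else:
    provider, device = "CPUExecutionProvider", "cpu"
print(f"Using provider: {provider} (device: {device})")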
@@ -73,33 +76,23 @@ except Exception as e:
 # --- Chat Function ---
 def chat_function(message: str, history: list):
     if not model_loaded_successfully or model is None or tokenizer is None:
-        # return an error message when the model failed to load
         return "Error: The AI model is not loaded. Please check the application logs."
 
     try:
-        # convert the chat history into the messages format
+        # convert the chat history
         chat_messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
         for user_msg, model_msg in history:
-            # also guard against None values (possible in Gradio's initial state, etc.)
-            if user_msg:
-                chat_messages.append({"role": "user", "content": user_msg})
-            if model_msg:
-                chat_messages.append({"role": "model", "content": model_msg})
-        if message:  # append the current user message
-            chat_messages.append({"role": "user", "content": message})
-
-        # build the prompt (try apply_chat_template, fall back to manual construction)
+            if user_msg: chat_messages.append({"role": "user", "content": user_msg})
+            if model_msg: chat_messages.append({"role": "model", "content": model_msg})
+        if message: chat_messages.append({"role": "user", "content": message})
+
+        # build the prompt
         prompt = ""
         try:
-            prompt = tokenizer.apply_chat_template(
-                chat_messages,
-                tokenize=False,
-                add_generation_prompt=True
-            )
+            prompt = tokenizer.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=True)
         except Exception as template_error:
             print(f"Warning: Failed to apply chat template ({template_error}). Using manual prompt construction.")
             prompt_parts = ["<start_of_turn>system\nYou are a helpful AI assistant.<end_of_turn>"]
-            # note: the model message in history may be None
             for user_msg, model_msg in history:
                 if user_msg: prompt_parts.append(f"<start_of_turn>user\n{user_msg}<end_of_turn>")
                 if model_msg: prompt_parts.append(f"<start_of_turn>model\n{model_msg}<end_of_turn>")
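
The manual fallback, which continues into the next hunk, hand-builds Gemma-style turn markup. As a worked example: for an empty history and the message "Hello!", the fallback produces roughly the prompt below (the append of the current user turn happens on an unchanged line between these hunks, so that part is inferred from context):

# Worked example of the fallback prompt for history=[] and message="Hello!".
prompt_parts = [
    "<start_of_turn>system\nYou are a helpful AI assistant.<end_of_turn>",
    "<start_of_turn>user\nHello!<end_of_turn>",  # current turn, inferred
    "<start_of_turn>model",
]
print("\n".join(prompt_parts))
# <start_of_turn>system
# You are a helpful AI assistant.<end_of_turn>
# <start_of_turn>user
# Hello!<end_of_turn>
# <start_of_turn>model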
@@ -107,14 +100,12 @@ def chat_function(message: str, history: list):
             prompt_parts.append("<start_of_turn>model")
             prompt = "\n".join(prompt_parts)
 
-        # print(f"--- PROMPT --- \n{prompt}\n--------------")
-
-        # tokenize the input and move it to the device
+        # tokenize the input
         inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
         # generate the response
         print("Generating response...")
-        with torch.no_grad():  # disable gradient computation during inference
+        with torch.no_grad():
             outputs = model.generate(
                 **inputs,
                 max_new_tokens=512,
@@ -122,50 +113,34 @@ def chat_function(message: str, history: list):
                 temperature=0.7,
                 top_k=50,
                 top_p=0.9,
-                pad_token_id=tokenizer.eos_token_id  # use the EOS token as the padding token
+                pad_token_id=tokenizer.eos_token_id
             )
         print("Generation complete.")
 
-        # decode (excluding the input portion)
+        # decode
         input_token_len = inputs['input_ids'].shape[1]
         generated_tokens = outputs[0][input_token_len:]
         response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
-
-        # post-processing
         response = response.replace("<end_of_turn>", "").strip()
-
-        # print(f"--- RESPONSE --- \n{response}\n--------------")
-
-        # handle empty responses
         if not response:
             print("Warning: Generated empty response.")
             response = "Sorry, I couldn't generate a response for that."
-
         return response
 
     except Exception as e:
         print(f"!!!!!!!!!!!!!! Error during generation !!!!!!!!!!!!!!")
         print(f"Error type: {type(e).__name__}")
         print(f"Error message: {e}")
-        print("Input message:", message)
-        # traceback.print_exc()  # print a detailed traceback if needed
         print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
         return f"Sorry, an error occurred during response generation. Please check logs."
 
-
-# --- Gradio Interface (modified) ---
+# --- Gradio Interface ---
 print("Creating Gradio Interface...")
 iface = gr.ChatInterface(
-    fn=chat_function,  # model-load failures are handled inside chat_function
+    fn=chat_function,
     title="AI Assistant (Gemma 3 1B ONNX-GQA)",
     description=f"Chat with {MODEL_ID}. Model loaded: {model_loaded_successfully}",
-    # add type='messages' to the chatbot widget
     chatbot=gr.Chatbot(height=600, type="messages", bubble_full_width=False),
-    # remove unsupported button arguments
-    # retry_btn=None,  # removed
-    # undo_btn=None,  # removed
-    # clear_btn=None,  # removed
-    # use the default button instead of submit_btn
     theme=gr.themes.Soft(),
     examples=[["Hello!"], ["Write a poem about the internet."]]
 )
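
One caveat with the type="messages" change: in that mode Gradio passes history as a list of {"role": ..., "content": ...} dicts rather than (user, assistant) tuples, so the tuple unpacking in chat_function would fail on a non-empty history. A hedged sketch of a normalizer that accepts both formats (the helper name is ours, not part of app.py):

# Hypothetical helper (not in app.py): normalize Gradio chat history to
# (user_msg, model_msg) pairs, accepting both the legacy tuple format and
# the type="messages" dict format.
def history_to_pairs(history):
    pairs, pending_user = [], None
    for turn in history:
        if isinstance(turn, dict):  # type="messages": {"role", "content"} dicts
            if turn.get("role") == "user":
                pending_user = turn.get("content")
            elif turn.get("role") == "assistant":
                pairs.append((pending_user, turn.get("content")))
                pending_user = None
        else:  # legacy format: (user, assistant) tuples
            pairs.append((turn[0], turn[1]))
    if pending_user is not None:
        pairs.append((pending_user, None))
    return pairs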
@@ -173,5 +148,4 @@ iface = gr.ChatInterface(
 # --- Launch App ---
 if __name__ == "__main__":
     print("Launching Gradio App...")
-    # launch the interface even when model loading fails, but show an error message
     iface.launch()
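
Because chat_function checks model_loaded_successfully up front, the degraded path can be smoke-tested without downloading the model:

# Quick smoke test of the fallback path: with the model not loaded,
# chat_function returns the "not loaded" error string instead of raising.
print(chat_function("Hello!", []))
# Expected: "Error: The AI model is not loaded. Please check the application logs."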