import gradio as gr
import torch
import os
# Removed the __version__ import from optimum.onnxruntime
from transformers import AutoTokenizer, __version__ as transformers_version
from optimum.onnxruntime import ORTModelForCausalLM
# import optimum  # optional: check optimum's own version
# --- Configuration ---
MODEL_ID = "onnx-community/gemma-3-1b-it-ONNX-GQA"
ONNX_FILE_NAME = None  # None lets optimum pick the default ONNX weights file
print(f"Using Transformers version: {transformers_version}")
# try:
#     print(f"Using Optimum version: {optimum.__version__}")  # try reading the version another way
# except AttributeError:
#     print("Could not determine Optimum version automatically.")
print(f"Using Gradio version: {gr.__version__}")
# --- Device Selection ---
try:
    if torch.cuda.is_available():
        device = "cuda:0"
        provider = "CUDAExecutionProvider"
        print("Attempting to use GPU (CUDA).")
    else:
        device = "cpu"
        provider = "CPUExecutionProvider"
        print("Using CPU.")
except Exception as e:
    print(f"Device detection error: {e}. Defaulting to CPU.")
    device = "cpu"
    provider = "CPUExecutionProvider"
# --- Model and Tokenizer Loading ---
model = None
tokenizer = None
model_loaded_successfully = False
print(f"Attempting to load model: {MODEL_ID}")
print(f"Using device: {device}, Execution Provider: {provider}")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
    print("Tokenizer loaded successfully.")
    # Attempt to load the ONNX model
    model = ORTModelForCausalLM.from_pretrained(
        MODEL_ID,
        provider=provider,
        use_cache=True,
    )
    print(f"ONNX Model '{MODEL_ID}' loaded successfully with provider '{provider}'.")
    model_loaded_successfully = True
except ValueError as ve:
    # Handle unsupported model type errors
    print("!!!!!!!!!!!!!! CRITICAL MODEL LOADING ERROR (ValueError) !!!!!!!!!!!!!!")
    print(f"Model: {MODEL_ID}")
    print(f"Error message: {ve}")
    print("This likely means the installed 'transformers' library version does NOT support the 'gemma3_text' architecture.")
    print("Ensure 'requirements.txt' specifies a recent version (e.g., transformers>=4.41.0) and the Space has been rebuilt/restarted.")
    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    model_loaded_successfully = False
except Exception as e:
    # Handle any other exception
    print("!!!!!!!!!!!!!! UNEXPECTED MODEL LOADING ERROR !!!!!!!!!!!!!!")
    print(f"Model: {MODEL_ID}")
    print(f"Error type: {type(e).__name__}")
    print(f"Error message: {e}")
    print("Check Space resources (memory limits), network connection, or other dependencies.")
    print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    model_loaded_successfully = False
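# If a specific ONNX weights file from the repo were needed (e.g. a quantized
# variant), it could be selected explicitly. A sketch, assuming optimum's
# `file_name` argument and an illustrative file name:
#   model = ORTModelForCausalLM.from_pretrained(
#       MODEL_ID, provider=provider, use_cache=True, file_name="onnx/model_q4.onnx"
#   )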
# --- Chat Function ---
def chat_function(message: str, history: list):
    if not model_loaded_successfully or model is None or tokenizer is None:
        return "Error: The AI model is not loaded. Please check the application logs."
    try:
        # Convert the chat history; with a type="messages" chatbot, Gradio passes
        # history as a list of {"role", "content"} dicts. Gemma's template expects
        # "model" rather than "assistant" for the assistant role.
        chat_messages = [{"role": "system", "content": "You are a helpful AI assistant."}]
        for msg in history:
            role = "model" if msg.get("role") == "assistant" else msg.get("role", "user")
            content = msg.get("content", "")
            if content:
                chat_messages.append({"role": role, "content": content})
        if message:
            chat_messages.append({"role": "user", "content": message})
        # Build the prompt
        prompt = ""
        try:
            prompt = tokenizer.apply_chat_template(chat_messages, tokenize=False, add_generation_prompt=True)
        except Exception as template_error:
            print(f"Warning: Failed to apply chat template ({template_error}). Using manual prompt construction.")
            prompt_parts = ["<start_of_turn>system\nYou are a helpful AI assistant.<end_of_turn>"]
            for msg in chat_messages[1:]:
                prompt_parts.append(f"<start_of_turn>{msg['role']}\n{msg['content']}<end_of_turn>")
            prompt_parts.append("<start_of_turn>model")
            prompt = "\n".join(prompt_parts)
        # Tokenize the input, keeping it on the same device as the execution provider
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        # Generate the response
        print("Generating response...")
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=512,
                do_sample=True,
                temperature=0.7,
                top_k=50,
                top_p=0.9,
                pad_token_id=tokenizer.eos_token_id
            )
        print("Generation complete.")
        # Decode only the newly generated tokens (skip the prompt)
        input_token_len = inputs['input_ids'].shape[1]
        generated_tokens = outputs[0][input_token_len:]
        response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
        response = response.replace("<end_of_turn>", "").strip()
        if not response:
            print("Warning: Generated empty response.")
            response = "Sorry, I couldn't generate a response for that."
        return response
    except Exception as e:
        print("!!!!!!!!!!!!!! Error during generation !!!!!!!!!!!!!!")
        print(f"Error type: {type(e).__name__}")
        print(f"Error message: {e}")
        print("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!")
        return "Sorry, an error occurred during response generation. Please check the logs."
# --- Gradio Interface ---
print("Creating Gradio Interface...")
iface = gr.ChatInterface(
    fn=chat_function,
    title="AI Assistant (Gemma 3 1B ONNX-GQA)",
    description=f"Chat with {MODEL_ID}. Model loaded: {model_loaded_successfully}",
    chatbot=gr.Chatbot(height=600, type="messages", bubble_full_width=False),
    theme=gr.themes.Soft(),
    examples=[["Hello!"], ["Write a poem about the internet."]]
)
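# Request queuing (iface.queue()) could be enabled explicitly before launch to
# control concurrency; recent Gradio versions queue events by default, so it is
# left out here.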
# --- Launch App ---
if __name__ == "__main__":
print("Launching Gradio App...")
iface.launch()