kimhyunwoo committed (verified)
Commit c5ec987 · 1 Parent(s): a41650d

Update app.py

Files changed (1):
  1. app.py +58 -76
app.py CHANGED
@@ -4,133 +4,115 @@ import torch
 import os
 import gradio as gr

-# --- 1. Authentication (Using Environment Variable - the ONLY correct way for Spaces) ---

-# Hugging Face Spaces CANNOT use interactive login. You MUST use an environment variable.
-# 1. Go to your Space's settings.
-# 2. Click on "Repository Secrets".
-# 3. Click "New Secret".
-# 4. Name the secret: HUGGING_FACE_HUB_TOKEN
-# 5. Paste your Hugging Face API token (with read access) as the value.
-# 6. Save the secret.
-
-# The login() call below will now automatically use the environment variable.
-login()

-# --- 2. Model and Tokenizer Setup (with comprehensive error handling) ---

 def load_model_and_tokenizer(model_name="google/gemma-3-1b-it"):
-    """Loads the model and tokenizer, handling potential errors."""
     try:
-        # Suppress unnecessary warning messages from transformers
         logging.set_verbosity_error()
-
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForCausalLM.from_pretrained(
             model_name,
-            device_map="auto",  # Automatically use GPU if available, else CPU
-            torch_dtype=torch.bfloat16,  # Use bfloat16 for speed/memory if supported
-            attn_implementation="flash_attention_2"  # Use Flash Attention 2 if supported
         )
         return model, tokenizer
-
     except Exception as e:
-        print(f"ERROR: Failed to load model or tokenizer: {e}")
-        print("\nTroubleshooting Steps:")
-        print("1. Ensure you have a Hugging Face account and have accepted the model's terms.")
-        print("2. Verify your internet connection.")
-        print("3. Double-check the model name: 'google/gemma-3-1b-it'")
-        print("4. Ensure you are properly authenticated using a Repository Secret (see above).")
-        print("5. If using a GPU, ensure your CUDA drivers and PyTorch are correctly installed.")
-        # Instead of exiting, raise the exception to be caught by Gradio
-        raise
-
-model, tokenizer = load_model_and_tokenizer()
-

-# --- 3. Chat Template Function (CRITICAL for conversational models) ---

 def apply_chat_template(messages, tokenizer):
-    """Applies the appropriate chat template."""
     try:
         if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
             return tokenizer.apply_chat_template(
                 messages, tokenize=False, add_generation_prompt=True
             )
         else:
-            print("WARNING: Tokenizer does not have a defined chat_template. Using a fallback.")
             chat_template = "{% for message in messages %}" \
                             "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}" \
                             "{% endfor %}" \
                             "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
             return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, chat_template=chat_template)
-
     except Exception as e:
-        print(f"ERROR: Failed to apply chat template: {e}")
-        raise  # Re-raise to be caught by Gradio
-

 # --- 4. Text Generation Function ---

 def generate_response(messages, model, tokenizer, max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.2):
     """Generates a response."""
     prompt = apply_chat_template(messages, tokenizer)
-
     try:
         pipeline_instance = pipeline(
-            "text-generation",
-            model=model,
-            tokenizer=tokenizer,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
             model_kwargs={"attn_implementation": "flash_attention_2"}
-        )
-
         outputs = pipeline_instance(
-            prompt,
-            max_new_tokens=max_new_tokens,
-            do_sample=True,
-            temperature=temperature,
-            top_k=top_k,
-            top_p=top_p,
-            repetition_penalty=repetition_penalty,
-            pad_token_id=tokenizer.eos_token_id,
         )

-        generated_text = outputs[0]["generated_text"][len(prompt):].strip()
-        return generated_text

-    except Exception as e:
-        print(f"ERROR: Failed to generate response: {e}")
-        raise  # Re-raise the exception


-# --- 5. Gradio Interface ---

-def predict(message, history):
     if not history:
         history = []
-    messages = []
-    for user_msg, bot_response in history:
-        messages.append({"role": "user", "content": user_msg})
-        if bot_response:  # Check if bot_response is not None
-            messages.append({"role": "model", "content": bot_response})
     messages.append({"role": "user", "content": message})

     try:
-        response = generate_response(messages, model, tokenizer)
-        history.append((message, response))
-        return "", history
     except Exception as e:
-        # Catch any exceptions during generation and display in the UI
-        return f"Error: {e}", history
-

 with gr.Blocks() as demo:
-    chatbot = gr.Chatbot(label="Gemma Chatbot", height=500)
-    msg = gr.Textbox(placeholder="Ask me anything!", container=False, scale=7)
-    clear = gr.ClearButton([msg, chatbot])

-    msg.submit(predict, [msg, chatbot], [msg, chatbot])

 demo.launch()
 
 import os
 import gradio as gr

+# --- 1. Authentication (Using User-Provided Token) ---

+def authenticate(token):
+    """Attempts to authenticate with the provided token."""
+    try:
+        login(token=token)
+        return True
+    except Exception as e:
+        print(f"Authentication failed: {e}")
+        return False

+# --- 2. Model and Tokenizer Setup ---

 def load_model_and_tokenizer(model_name="google/gemma-3-1b-it"):
+    """Loads the model and tokenizer."""
     try:
         logging.set_verbosity_error()
         tokenizer = AutoTokenizer.from_pretrained(model_name)
         model = AutoModelForCausalLM.from_pretrained(
             model_name,
+            device_map="auto",
+            torch_dtype=torch.bfloat16,
+            attn_implementation="flash_attention_2"
         )
         return model, tokenizer
     except Exception as e:
+        print(f"ERROR: Failed to load model/tokenizer: {e}")
+        raise  # Re-raise for Gradio

+# --- 3. Chat Template Function ---

 def apply_chat_template(messages, tokenizer):
+    """Applies the chat template."""
     try:
         if hasattr(tokenizer, "chat_template") and tokenizer.chat_template:
             return tokenizer.apply_chat_template(
                 messages, tokenize=False, add_generation_prompt=True
             )
         else:
+            print("WARNING: Tokenizer lacks chat_template. Using fallback.")
             chat_template = "{% for message in messages %}" \
                             "{{ '<start_of_turn>' + message['role'] + '\n' + message['content'] + '<end_of_turn>\n' }}" \
                             "{% endfor %}" \
                             "{% if add_generation_prompt %}{{ '<start_of_turn>model\n' }}{% endif %}"
             return tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, chat_template=chat_template)
     except Exception as e:
+        print(f"ERROR: Chat template application failed: {e}")
+        raise

 # --- 4. Text Generation Function ---

 def generate_response(messages, model, tokenizer, max_new_tokens=256, temperature=0.7, top_k=50, top_p=0.95, repetition_penalty=1.2):
     """Generates a response."""
     prompt = apply_chat_template(messages, tokenizer)
     try:
         pipeline_instance = pipeline(
+            "text-generation", model=model, tokenizer=tokenizer,
+            torch_dtype=torch.bfloat16, device_map="auto",
             model_kwargs={"attn_implementation": "flash_attention_2"}
+        )
         outputs = pipeline_instance(
+            prompt, max_new_tokens=max_new_tokens, do_sample=True,
+            temperature=temperature, top_k=top_k, top_p=top_p,
+            repetition_penalty=repetition_penalty, pad_token_id=tokenizer.eos_token_id
         )
+        return outputs[0]["generated_text"][len(prompt):].strip()
+    except Exception as e:
+        print(f"ERROR: Response generation failed: {e}")
+        raise

+# --- 5. Gradio Interface ---
+model = None  # Initialize model and tokenizer as global variables
+tokenizer = None

+def chat(token, message, history):
+    global model, tokenizer  # Access the global model and tokenizer

+    if not authenticate(token):
+        return "Authentication failed. Please enter a valid Hugging Face token.", history

+    if model is None or tokenizer is None:
+        try:
+            model, tokenizer = load_model_and_tokenizer()
+        except Exception as e:
+            return f"Model loading error: {e}", history

     if not history:
         history = []
+    messages = [{"role": "user", "content": msg} for msg, _ in history]
+    messages.extend([{"role": "model", "content": resp} for _, resp in history if resp])
     messages.append({"role": "user", "content": message})

     try:
+        response = generate_response(messages, model, tokenizer)
+        history.append((message, response))
+        return "", history
     except Exception as e:
+        return f"Error during generation: {e}", history

 with gr.Blocks() as demo:
+    gr.Markdown("# Gemma Chatbot")
+    gr.Markdown("Enter your Hugging Face API token (read access required):")
+    token_input = gr.Textbox(label="Hugging Face Token", type="password")  # Use type="password"
+    chatbot = gr.Chatbot(label="Chat", height=400)
+    msg_input = gr.Textbox(label="Message", placeholder="Ask me anything!")
+    clear_btn = gr.ClearButton([msg_input, chatbot])
+
+    msg_input.submit(chat, [token_input, msg_input, chatbot], [msg_input, chatbot])
+    clear_btn.click(lambda: (None, []), [], [msg_input, chatbot])


 demo.launch()
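
The commit replaces the Space-secret login with a per-user token typed into the UI. For Spaces that still define the HUGGING_FACE_HUB_TOKEN repository secret described in the removed comments, a small helper could accept either source. The sketch below is not part of the commit: the name resolve_token and the fall-back-to-environment behaviour are assumptions layered on the login(token=...) call that the new authenticate() already uses.

import os
from huggingface_hub import login

def resolve_token(user_token):
    """Hypothetical helper: prefer the token pasted into the textbox,
    otherwise fall back to the HUGGING_FACE_HUB_TOKEN repository secret
    (the mechanism the previous version of app.py relied on)."""
    token = (user_token or "").strip() or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    if not token:
        return None
    try:
        login(token=token)  # same call the new authenticate() makes
        return token
    except Exception as e:
        print(f"Authentication failed: {e}")
        return None

With something like this, chat() could treat an empty textbox as "use the Space secret" instead of failing outright, leaving the rest of the handler unchanged.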