looker01202 committed
Commit 5573ab1 · Parent: 0c79881

Gemini changes added 1

Files changed (1): app.py (+101, -48)

app.py CHANGED
@@ -23,14 +23,20 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 def load_model():
     print(f"🔍 Loading model: {primary_checkpoint}")
     try:
+        # Use optimized loading settings suitable for Granite
+        load_kwargs = {
+            "use_fast": True,
+            "torch_dtype": torch.float16,
+            "low_cpu_mem_usage": True
+        } if primary_checkpoint.startswith("ibm-granite") else {}
+
         tokenizer = AutoTokenizer.from_pretrained(
             primary_checkpoint,
-            use_fast=True
+            **{k: v for k, v in load_kwargs.items() if k == 'use_fast'} # Only pass use_fast to tokenizer
         )
         model = AutoModelForCausalLM.from_pretrained(
             primary_checkpoint,
-            torch_dtype=torch.float16,
-            low_cpu_mem_usage=True
+            **{k: v for k, v in load_kwargs.items() if k != 'use_fast'} # Pass other kwargs to model
         ).to(device)
         print(f"✅ Loaded primary {primary_checkpoint}")
         return tokenizer, model, primary_checkpoint
@@ -53,9 +59,15 @@ print(tokenizer.chat_template)
 def load_hotel_docs(hotel_id):
     path = os.path.join("knowledge", f"{hotel_id}.txt")
     if not os.path.exists(path):
+        print(f"⚠️ Knowledge file not found: {path}")
+        return []
+    try:
+        with open(path, encoding="utf-8") as f:
+            content = f.read().strip()
+        return [(hotel_id, content)]
+    except Exception as e:
+        print(f"❌ Error reading knowledge file {path}: {e}")
         return []
-    content = open(path, encoding="utf-8").read().strip()
-    return [(hotel_id, content)]
 
 # Chat function
 def chat(message, history, hotel_id):
@@ -69,7 +81,7 @@ def chat(message, history, hotel_id):
 
     # Yield user message immediately
     ui_history = [{"role": r, "content": c} for r, c in history_tuples]
-    yield ui_history, ""
+    yield ui_history, "" # Update chat, clear textbox
 
     # Local Qwen flow
     if not is_space:
@@ -81,105 +93,146 @@ def chat(message, history, hotel_id):
             add_generation_prompt=True
         )
         inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
-
+
         with torch.no_grad():
             outputs = model.generate(inputs, max_new_tokens=1024, do_sample=True)
-
+
         decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)
+        print("--- Qwen Raw Output ---")
         print(decoded)
-
-        # Extract assistant response
-        response = decoded.split("<|im_start|>assistant")[-1]
-        response = response.split("<|im_end|>")[0].strip()
+        print("-----------------------")
+
+        # Extract assistant response for Qwen
+        try:
+            response = decoded.split("<|im_start|>assistant")[-1]
+            response = response.split("<|im_end|>")[0].strip()
+            if not response: # Handle potential empty split
+                response = "Sorry, I encountered an issue generating a response."
+        except IndexError:
+            print("❌ Error splitting Qwen response.")
+            response = "Sorry, I couldn't parse the model's response."
+
+    # IBM Granite RAG flow (Space environment)
     else:
-        # IBM Granite RAG flow
-        system_prompt = (
-            "Knowledge Cutoff Date: April 2024. Today's Date: April 12, 2025. "
-            "You are Alexander, the front desk assistant at Family Village Inn in Cyprus. "
-            "You only know what's in the provided documents. "
-            "Greet guests politely, but only chit-chat when it helps answer hotel questions. "
-            "Answer using only facts from the documents; if unavailable, say you cannot answer."
+        # --- Start: Dynamic System Prompt Loading ---
+        default_system_prompt = (
+            "You are a helpful hotel assistant. Use only the provided documents to answer questions about the hotel. "
+            "Greet guests politely. If the information needed to answer the question is not available in the documents, "
+            "inform the user that the question cannot be answered based on the available data."
         )
-        messages = [{"role": "system", "content": system_prompt}]
-        for doc_id, doc_content in load_hotel_docs(hotel_id):
+        system_prompt_filename = f"{hotel_id}-system.txt"
+        system_prompt_path = os.path.join("knowledge", system_prompt_filename)
+        system_prompt_content = default_system_prompt # Start with default
+
+        if os.path.exists(system_prompt_path):
+            try:
+                with open(system_prompt_path, "r", encoding="utf-8") as f:
+                    loaded_prompt = f.read().strip()
+                if loaded_prompt: # Use file content only if it's not empty
+                    system_prompt_content = loaded_prompt
+                    print(f"✅ Loaded system prompt from: {system_prompt_path}")
+                else:
+                    print(f"⚠️ System prompt file '{system_prompt_path}' is empty. Using default.")
+            except Exception as e:
+                print(f"❌ Error reading system prompt file '{system_prompt_path}': {e}. Using default.")
+        else:
+            print(f"⚠️ System prompt file not found: '{system_prompt_path}'. Using default.")
+        # --- End: Dynamic System Prompt Loading ---
+
+        messages = [{"role": "system", "content": system_prompt_content}]
+
+        # Load and add hotel document(s)
+        hotel_docs = load_hotel_docs(hotel_id)
+        if not hotel_docs:
+            # If no knowledge doc found, inform user and stop
+            ui_history.append({"role": "assistant", "content": f"Sorry, I don't have specific information loaded for the hotel '{hotel_id}'."})
+            yield ui_history, "" # Update chat, keep textbox cleared
+            return # Exit the function early
+
+        for doc_id, doc_content in hotel_docs:
             messages.append({"role": "document", "content": doc_content, "document_id": doc_id})
+
         # Include full history including the new user message
         for role, content in history_tuples:
             messages.append({"role": role, "content": content})
 
-        # Apply the template to the chat dictionary to create a templated string which can be tokenized
+        # Apply the template
         input_text = tokenizer.apply_chat_template(
             messages,
             tokenize=False,
             add_generation_prompt=True
         )
-
-        # Print the templated string
-        print("printing templated chat\n")
+
+        print("--- Granite Templated Input ---")
         print(input_text)
+        print("-----------------------------")
 
-        # Turn into tensors
         inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
 
         with torch.no_grad():
-            outputs = model.generate(inputs, max_new_tokens=1024, do_sample=True)
-
+            # Using do_sample=False for more deterministic RAG based on context
+            outputs = model.generate(inputs, max_new_tokens=1024, do_sample=False)
+
         decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)
 
-        # Print the templated string
-        print("printing reply from model\n")
+        print("--- Granite Raw Output ---")
         print(decoded)
+        print("--------------------------")
 
-        response = decoded.split("<|start_of_role|>assistant<|end_of_role|>")[-1]
-        response = response.split("<|end_of_text|>")[0].strip()
+        # Extract assistant response for Granite
+        try:
+            response = decoded.split("<|start_of_role|>assistant<|end_of_role|>")[-1]
+            response = response.split("<|end_of_text|>")[0].strip()
+            if not response: # Handle potential empty split
+                response = "Sorry, I encountered an issue generating a response."
+        except IndexError:
+            print("❌ Error splitting Granite response.")
+            response = "Sorry, I couldn't parse the model's response."
+
+    # Add the final assistant reply to the UI history
+    ui_history.append({"role": "assistant", "content": response})
 
-    # add the assistant reply to the running transcript
-    ui_history.append({"role": "assistant", "content": response})
-
     # Final yield with assistant reply
-    yield ui_history, ""
+    yield ui_history, "" # Update chat, keep textbox cleared
 
 # Available hotels
 hotel_ids = ["cyprus-guesthouse-family", "coastal-villa-family", "village-inn-family"]
 
-# Gradio UI
 # Gradio UI
 with gr.Blocks() as demo:
-    # ⬇️ NEW panel wrapper
     with gr.Column(variant="panel"):
-
         gr.Markdown("### 🏨 Multi‑Hotel Chatbot Demo")
         gr.Markdown(f"**Running:** {model_name}")
 
         hotel_selector = gr.Dropdown(
             hotel_ids,
             label="Hotel",
-            value=hotel_ids[0]
+            value=hotel_ids[0] # Default selection
         )
 
-        # Chat window in its own row so it stretches
         with gr.Row():
-            chatbot = gr.Chatbot(type="messages")
+            # Use type="messages" for the dictionary format expected by the chat function
+            chatbot = gr.Chatbot(type="messages", label="Chat History")
 
         msg = gr.Textbox(
             show_label=False,
             placeholder="Ask about the hotel..."
         )
 
-        # Clear‑history button
-        gr.Button("Clear").click(lambda: ([], ""), None, [chatbot, msg])
+        # Clear button needs to reset chatbot to None or empty list, and clear textbox
+        clear_btn = gr.Button("Clear")
+        clear_btn.click(lambda: (None, ""), None, [chatbot, msg]) # Reset chatbot history to None
 
-        # Wire the textbox to the chat function
+        # Wire the textbox submission
         msg.submit(
             fn=chat,
             inputs=[msg, chatbot, hotel_selector],
-            outputs=[chatbot, msg]
+            outputs=[chatbot, msg] # chatbot updates, msg clears
         )
 
-    # Anything outside the column shows below the panel
     gr.Markdown("⚠️ Pause the Space when done to avoid charges.")
 
-# Enable streaming queue for generator-based chat
+# Enable streaming queue
 demo.queue(default_concurrency_limit=2, max_size=32)
 
 if __name__ == "__main__":
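The revised load_model() builds a single load_kwargs dict, gated on an ibm-granite checkpoint, and filters it per call so that use_fast only reaches the tokenizer. A roughly equivalent but more explicit sketch (the helper name load_model_explicit is hypothetical; primary_checkpoint and device are the globals app.py already defines):

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

def load_model_explicit(primary_checkpoint: str, device: str):
    # The tokenizer only understands use_fast; dtype/memory options belong to the model.
    tokenizer = AutoTokenizer.from_pretrained(primary_checkpoint, use_fast=True)

    model_kwargs = (
        {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
        if primary_checkpoint.startswith("ibm-granite")
        else {}
    )
    model = AutoModelForCausalLM.from_pretrained(primary_checkpoint, **model_kwargs).to(device)
    return tokenizer, model, primary_checkpoint

Keeping the two keyword-argument sets apart avoids the dict-comprehension filtering and makes it obvious which option goes to which from_pretrained call.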
 
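In the Granite branch the prompt is assembled as a message list: the (now file-driven) system prompt first, one "document" turn per knowledge snippet, then the running user/assistant history, after which the tokenizer's chat template renders everything into one prompt string. A condensed sketch of that assembly (build_granite_messages is a hypothetical helper; tokenizer, load_hotel_docs, system_prompt_content, hotel_id, and history_tuples are the names used in app.py):

def build_granite_messages(system_prompt, hotel_docs, history_tuples):
    # System prompt, then one "document" entry per knowledge snippet,
    # then the chat history (whose last item is the new user message).
    messages = [{"role": "system", "content": system_prompt}]
    for doc_id, doc_content in hotel_docs:
        messages.append({"role": "document", "content": doc_content, "document_id": doc_id})
    for role, content in history_tuples:
        messages.append({"role": role, "content": content})
    return messages

messages = build_granite_messages(system_prompt_content, load_hotel_docs(hotel_id), history_tuples)
input_text = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # return the templated string, not token IDs
    add_generation_prompt=True,  # append the assistant header so the model starts replying
)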
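Both branches recover the reply by splitting on model-specific markers (Qwen's <|im_start|>assistant ... <|im_end|> and Granite's <|start_of_role|>assistant<|end_of_role|> ... <|end_of_text|>). Since str.split always returns at least one element and therefore never raises IndexError, the empty-string check is what actually catches a failed parse; a small sketch that makes the failure case explicit instead (extract_reply is hypothetical, not part of the commit):

def extract_reply(decoded: str, start_marker: str, end_marker: str,
                  fallback: str = "Sorry, I encountered an issue generating a response.") -> str:
    # If the assistant marker is missing, split() returns the whole string;
    # treat that as a parse failure rather than echoing the prompt back.
    if start_marker not in decoded:
        return fallback
    reply = decoded.split(start_marker)[-1].split(end_marker)[0].strip()
    return reply or fallback

# Qwen-style markers
response = extract_reply(decoded, "<|im_start|>assistant", "<|im_end|>")
# Granite-style markers
response = extract_reply(decoded, "<|start_of_role|>assistant<|end_of_role|>", "<|end_of_text|>")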
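The UI side depends on chat() being a generator: it yields (ui_history, "") once to echo the user turn and again with the assistant reply, and the configured queue streams those updates to the browser. A stripped-down sketch of the same wiring, assuming Gradio 4.x (where gr.Chatbot(type="messages") and the default_concurrency_limit queue argument used in the commit are available) and the chat function and hotel_ids list from app.py:

import gradio as gr

with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        hotel_selector = gr.Dropdown(hotel_ids, label="Hotel", value=hotel_ids[0])
        chatbot = gr.Chatbot(type="messages", label="Chat History")  # expects role/content dicts
        msg = gr.Textbox(show_label=False, placeholder="Ask about the hotel...")

        # Each value yielded by the generator updates the chatbot and clears the textbox.
        msg.submit(fn=chat, inputs=[msg, chatbot, hotel_selector], outputs=[chatbot, msg])

        # Returning (None, "") resets the chat history and clears the textbox.
        gr.Button("Clear").click(lambda: (None, ""), None, [chatbot, msg])

demo.queue(default_concurrency_limit=2, max_size=32)  # queue settings used by the Space

if __name__ == "__main__":
    demo.launch()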