looker01202 committed
Commit 5573ab1 · 1 Parent(s): 0c79881
Gemini changes added 1
app.py
CHANGED
@@ -23,14 +23,20 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
 def load_model():
     print(f"🔍 Loading model: {primary_checkpoint}")
     try:
+        # Use optimized loading settings suitable for Granite
+        load_kwargs = {
+            "use_fast": True,
+            "torch_dtype": torch.float16,
+            "low_cpu_mem_usage": True
+        } if primary_checkpoint.startswith("ibm-granite") else {}
+
         tokenizer = AutoTokenizer.from_pretrained(
             primary_checkpoint,
-            use_fast
+            **{k: v for k, v in load_kwargs.items() if k == 'use_fast'} # Only pass use_fast to tokenizer
         )
         model = AutoModelForCausalLM.from_pretrained(
             primary_checkpoint,
-
-            low_cpu_mem_usage=True
+            **{k: v for k, v in load_kwargs.items() if k != 'use_fast'} # Pass other kwargs to model
         ).to(device)
         print(f"✅ Loaded primary {primary_checkpoint}")
         return tokenizer, model, primary_checkpoint
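A minimal sketch of the kwargs-splitting pattern introduced in this hunk, runnable on its own; the checkpoint name below is a placeholder rather than the Space's real checkpoint, and torch_dtype is replaced by a plain string so the sketch does not need torch installed.

# Sketch only: how the filter routes use_fast to the tokenizer and the rest to the model.
primary_checkpoint = "ibm-granite/example-checkpoint"  # placeholder id for illustration

load_kwargs = {
    "use_fast": True,
    "torch_dtype": "float16",        # stands in for torch.float16 in this sketch
    "low_cpu_mem_usage": True,
} if primary_checkpoint.startswith("ibm-granite") else {}

tokenizer_kwargs = {k: v for k, v in load_kwargs.items() if k == "use_fast"}
model_kwargs = {k: v for k, v in load_kwargs.items() if k != "use_fast"}

print(tokenizer_kwargs)  # {'use_fast': True}
print(model_kwargs)      # {'torch_dtype': 'float16', 'low_cpu_mem_usage': True}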
@@ -53,9 +59,15 @@ print(tokenizer.chat_template)
 def load_hotel_docs(hotel_id):
     path = os.path.join("knowledge", f"{hotel_id}.txt")
     if not os.path.exists(path):
+        print(f"⚠️ Knowledge file not found: {path}")
+        return []
+    try:
+        with open(path, encoding="utf-8") as f:
+            content = f.read().strip()
+        return [(hotel_id, content)]
+    except Exception as e:
+        print(f"❌ Error reading knowledge file {path}: {e}")
         return []
-    content = open(path, encoding="utf-8").read().strip()
-    return [(hotel_id, content)]
 
 # Chat function
 def chat(message, history, hotel_id):
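For local testing, a small sketch that stages a knowledge file in the layout load_hotel_docs expects; the hotel id matches one from the hotel_ids list later in the diff, and the sample text is invented.

import os

# Stage a sample knowledge file (content is made up for illustration).
os.makedirs("knowledge", exist_ok=True)
with open(os.path.join("knowledge", "cyprus-guesthouse-family.txt"), "w", encoding="utf-8") as f:
    f.write("Check-in is from 14:00. Breakfast is served 07:30-10:00.")

# load_hotel_docs("cyprus-guesthouse-family") should now return:
#   [("cyprus-guesthouse-family", "Check-in is from 14:00. Breakfast is served 07:30-10:00.")]
# and load_hotel_docs("unknown-hotel") should print a warning and return [].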
@@ -69,7 +81,7 @@ def chat(message, history, hotel_id):
 
     # Yield user message immediately
     ui_history = [{"role": r, "content": c} for r, c in history_tuples]
-    yield ui_history, ""
+    yield ui_history, "" # Update chat, clear textbox
 
     # Local Qwen flow
     if not is_space:
@@ -81,105 +93,146 @@ def chat(message, history, hotel_id):
             add_generation_prompt=True
         )
         inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
-
+
         with torch.no_grad():
             outputs = model.generate(inputs, max_new_tokens=1024, do_sample=True)
-
+
         decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)
+        print("--- Qwen Raw Output ---")
         print(decoded)
-
-
-        response
-
+        print("-----------------------")
+
+        # Extract assistant response for Qwen
+        try:
+            response = decoded.split("<|im_start|>assistant")[-1]
+            response = response.split("<|im_end|>")[0].strip()
+            if not response: # Handle potential empty split
+                response = "Sorry, I encountered an issue generating a response."
+        except IndexError:
+            print("❌ Error splitting Qwen response.")
+            response = "Sorry, I couldn't parse the model's response."
+
+    # IBM Granite RAG flow (Space environment)
     else:
-        #
-
-            "
-            "
-            "
-            "Greet guests politely, but only chit-chat when it helps answer hotel questions. "
-            "Answer using only facts from the documents; if unavailable, say you cannot answer."
+        # --- Start: Dynamic System Prompt Loading ---
+        default_system_prompt = (
+            "You are a helpful hotel assistant. Use only the provided documents to answer questions about the hotel. "
+            "Greet guests politely. If the information needed to answer the question is not available in the documents, "
+            "inform the user that the question cannot be answered based on the available data."
         )
-
-
+        system_prompt_filename = f"{hotel_id}-system.txt"
+        system_prompt_path = os.path.join("knowledge", system_prompt_filename)
+        system_prompt_content = default_system_prompt # Start with default
+
+        if os.path.exists(system_prompt_path):
+            try:
+                with open(system_prompt_path, "r", encoding="utf-8") as f:
+                    loaded_prompt = f.read().strip()
+                if loaded_prompt: # Use file content only if it's not empty
+                    system_prompt_content = loaded_prompt
+                    print(f"✅ Loaded system prompt from: {system_prompt_path}")
+                else:
+                    print(f"⚠️ System prompt file '{system_prompt_path}' is empty. Using default.")
+            except Exception as e:
+                print(f"❌ Error reading system prompt file '{system_prompt_path}': {e}. Using default.")
+        else:
+            print(f"⚠️ System prompt file not found: '{system_prompt_path}'. Using default.")
+        # --- End: Dynamic System Prompt Loading ---
+
+        messages = [{"role": "system", "content": system_prompt_content}]
+
+        # Load and add hotel document(s)
+        hotel_docs = load_hotel_docs(hotel_id)
+        if not hotel_docs:
+            # If no knowledge doc found, inform user and stop
+            ui_history.append({"role": "assistant", "content": f"Sorry, I don't have specific information loaded for the hotel '{hotel_id}'."})
+            yield ui_history, "" # Update chat, keep textbox cleared
+            return # Exit the function early
+
+        for doc_id, doc_content in hotel_docs:
             messages.append({"role": "document", "content": doc_content, "document_id": doc_id})
+
         # Include full history including the new user message
         for role, content in history_tuples:
             messages.append({"role": role, "content": content})
 
-        # Apply the template
+        # Apply the template
         input_text = tokenizer.apply_chat_template(
             messages,
             tokenize=False,
             add_generation_prompt=True
         )
-
-
-        print("printing templated chat\n")
+
+        print("--- Granite Templated Input ---")
         print(input_text)
+        print("-----------------------------")
 
-        # Turn into tensors
         inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
 
         with torch.no_grad():
-
-
+            # Using do_sample=False for more deterministic RAG based on context
+            outputs = model.generate(inputs, max_new_tokens=1024, do_sample=False)
+
         decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)
 
-
-        print("printing reply from model\n")
+        print("--- Granite Raw Output ---")
         print(decoded)
+        print("--------------------------")
 
-        response
-
+        # Extract assistant response for Granite
+        try:
+            response = decoded.split("<|start_of_role|>assistant<|end_of_role|>")[-1]
+            response = response.split("<|end_of_text|>")[0].strip()
+            if not response: # Handle potential empty split
+                response = "Sorry, I encountered an issue generating a response."
+        except IndexError:
+            print("❌ Error splitting Granite response.")
+            response = "Sorry, I couldn't parse the model's response."
+
+        # Add the final assistant reply to the UI history
+        ui_history.append({"role": "assistant", "content": response})
 
-    # add the assistant reply to the running transcript
-    ui_history.append({"role": "assistant", "content": response})
-
     # Final yield with assistant reply
-    yield ui_history, ""
+    yield ui_history, "" # Update chat, keep textbox cleared
 
 # Available hotels
 hotel_ids = ["cyprus-guesthouse-family", "coastal-villa-family", "village-inn-family"]
 
-# Gradio UI
 # Gradio UI
 with gr.Blocks() as demo:
-    # ⬇️ NEW panel wrapper
     with gr.Column(variant="panel"):
-
         gr.Markdown("### 🏨 Multi‑Hotel Chatbot Demo")
         gr.Markdown(f"**Running:** {model_name}")
 
         hotel_selector = gr.Dropdown(
             hotel_ids,
             label="Hotel",
-            value=hotel_ids[0]
+            value=hotel_ids[0] # Default selection
         )
 
-        # Chat window in its own row so it stretches
         with gr.Row():
-
+            # Use type="messages" for the dictionary format expected by the chat function
+            chatbot = gr.Chatbot(type="messages", label="Chat History")
 
         msg = gr.Textbox(
             show_label=False,
             placeholder="Ask about the hotel..."
         )
 
-        # Clear
-        gr.Button("Clear")
+        # Clear button needs to reset chatbot to None or empty list, and clear textbox
+        clear_btn = gr.Button("Clear")
+        clear_btn.click(lambda: (None, ""), None, [chatbot, msg]) # Reset chatbot history to None
 
-        # Wire the textbox
+        # Wire the textbox submission
        msg.submit(
             fn=chat,
             inputs=[msg, chatbot, hotel_selector],
-            outputs=[chatbot, msg]
+            outputs=[chatbot, msg] # chatbot updates, msg clears
         )
 
-    # Anything outside the column shows below the panel
     gr.Markdown("⚠️ Pause the Space when done to avoid charges.")
 
-# Enable streaming queue
+# Enable streaming queue
 demo.queue(default_concurrency_limit=2, max_size=32)
 
 if __name__ == "__main__":
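To sanity-check the new response extraction without loading the model, the same split logic from the Granite branch can be run on a hand-written decoded string; the transcript below is made up for illustration, only the role and end-of-text markers follow the pattern the code splits on.

# Sketch: apply the Granite-branch split logic to an invented decoded transcript.
decoded = (
    "<|start_of_role|>system<|end_of_role|>You are a helpful hotel assistant.<|end_of_text|>"
    "<|start_of_role|>user<|end_of_role|>What time is breakfast?<|end_of_text|>"
    "<|start_of_role|>assistant<|end_of_role|>Breakfast is served 07:30-10:00.<|end_of_text|>"
)

response = decoded.split("<|start_of_role|>assistant<|end_of_role|>")[-1]
response = response.split("<|end_of_text|>")[0].strip()
print(response)  # Breakfast is served 07:30-10:00.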
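A stripped-down, standalone sketch of the UI wiring this commit settles on (messages-format Chatbot, Clear button resetting both components); the echo handler is a stand-in for the real chat function, and the sketch assumes a Gradio version that supports type="messages".

import gradio as gr

def echo(message, history, hotel_id):
    # Stand-in for chat(): append the user turn and a canned reply, then clear the textbox.
    history = (history or []) + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": f"[{hotel_id}] You said: {message}"},
    ]
    return history, ""

hotel_ids = ["cyprus-guesthouse-family", "coastal-villa-family", "village-inn-family"]

with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        hotel_selector = gr.Dropdown(hotel_ids, label="Hotel", value=hotel_ids[0])
        chatbot = gr.Chatbot(type="messages", label="Chat History")
        msg = gr.Textbox(show_label=False, placeholder="Ask about the hotel...")
        clear_btn = gr.Button("Clear")
        clear_btn.click(lambda: (None, ""), None, [chatbot, msg])  # reset chat and textbox
        msg.submit(fn=echo, inputs=[msg, chatbot, hotel_selector], outputs=[chatbot, msg])

if __name__ == "__main__":
    demo.launch()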