looker01202 committed on
Commit
cef5bae
·
1 Parent(s): fc19e53

Can use either Granite or Qwen, depending on where the app is running
Files changed (1)
  1. app.py +141 -104
app.py CHANGED
@@ -1,112 +1,149 @@
-from transformers import AutoModelForCausalLM, AutoTokenizer
 import gradio as gr
-
-from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch

-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-primary_checkpoint = "ibm-granite/granite-3.2-2b-instruct"
-fallback_checkpoint = "Qwen/Qwen2.5-0.5B-Instruct"
-
-print(f"Attempting to load model on {device}...")
-
-
-#checkpoint = "HuggingFaceTB/SmolLM2-135M-Instruct"
-#checkpoint = "ibm-granite/granite-3.2-2b-instruct"
-#checkpoint = "Qwen/Qwen2.5-0.5B-Instruct"
-
-#device = "cpu"  # "cuda" or "cpu"
-#tokenizer = AutoTokenizer.from_pretrained(checkpoint, force_download=True)
-#AutoTokenizer.from_pretrained(checkpoint, force_download=True)
-
-#print("This is the template being used:\n\n")
-#print(tokenizer.chat_template)
-
-#model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
-
-try:
-    tokenizer = AutoTokenizer.from_pretrained(primary_checkpoint)
-    model = AutoModelForCausalLM.from_pretrained(primary_checkpoint).to(device)
-    model_name = primary_checkpoint
-    print(f"✅ Loaded primary model: {model_name}")
-except Exception as e:
-    print(f"⚠️ Failed to load primary model: {e}")
-    print(f"🔁 Falling back to smaller model: {fallback_checkpoint}")
-    tokenizer = AutoTokenizer.from_pretrained(fallback_checkpoint)
-    model = AutoModelForCausalLM.from_pretrained(fallback_checkpoint).to(device)
-    model_name = fallback_checkpoint
-    print(f"✅ Loaded fallback model: {model_name}")
-
-def chat(message, history):
-    # history looks like:
-    # [
-    #     {"role": "system", "content": ...},  # if you provided a system prompt
-    #     {"role": "user", "content": ...},
-    #     {"role": "assistant", "content": ...},
-    #     ...
-    # ]
-
-    # If we don't have a chat history yet, create one ready for sending to the model
-    if not history:
-        history = [{"role": "user", "content": message}]
-    else:
-        history = history + [{"role": "user", "content": message}]
-
-    # Apply the chat template, converting `history` into a single text prompt:
-    #input_text = tokenizer.apply_chat_template(history, tokenize=False)
-    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
-
-    #print("printing templated chat (pre-tokenized), ready for sending to the model\n")
-    #print(input_text)
-
-    # Tokenize the prompt ready for sending to the model
-    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
-
-    # Send the tokenized prompt to the model and capture the reply in 'outputs'
-    #outputs = model.generate(inputs, max_new_tokens=1024, temperature=0.1, top_p=0.9, do_sample=True)
-    outputs = model.generate(inputs, max_new_tokens=1024)
-
-    # De-tokenize the model outputs
-    decoded = tokenizer.decode(outputs[0])
-
-    # Grab the final model response
-    response = decoded.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0]
-
-    #print("and here is the response\n")
-
-    # Print out the reply from the model to the last question we just added
-    #print(tokenizer.decode(outputs[0]))
-    #print(response)
-
-    # Send this response back to gr.ChatInterface, which will display it to the user as the next assistant message.
-    return response
-
-demo = gr.ChatInterface(
-    fn=chat,
-    type="messages",
-    title="Hotel chat",
-    chatbot=gr.Chatbot(height=300, type="messages"),
-    textbox=gr.Textbox(placeholder="Ask me about the hotel", container=False, scale=7),
-    #description="This is the description",
-    #description="This is the description",
-
-    description="""
-    ### 🏨 Hotel Chatbot Demo
-    Ask anything about your hotel stay: room availability, check-in times, amenities, and more.
-
-    ---
-
-    ⚠️ **Reminder:**
-    When you're done demoing, **pause the Space** (top-right menu) to avoid GPU charges.
-    """,
-    theme="ocean",
-    examples=["Can you help me book a room?", "Do you have a pool?", "Can I check-in at midday?"],
-    save_history=True,
-)

 if __name__ == "__main__":
-    demo.launch()
 
+import os
+import getpass
 import gradio as gr
 import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+# Detect execution environment: Spaces runs as user 'gradio'
+is_space = (getpass.getuser() == "gradio")
+
+# Choose model checkpoints based on the environment
+if is_space:
+    primary_checkpoint = "ibm-granite/granite-3.3-8b-instruct"
+    fallback_checkpoint = "Qwen/Qwen2.5-0.5B-Instruct"
+else:
+    # Local development: use the smaller Qwen model only
+    primary_checkpoint = "Qwen/Qwen2.5-0.5B-Instruct"
+    fallback_checkpoint = None
+
+# Device setup
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# Load tokenizer and model (with fallback on Spaces)
+def load_model():
+    if not is_space:
+        tokenizer = AutoTokenizer.from_pretrained(primary_checkpoint)
+        model = AutoModelForCausalLM.from_pretrained(primary_checkpoint).to(device)
+        return tokenizer, model, primary_checkpoint
+    try:
+        tokenizer = AutoTokenizer.from_pretrained(primary_checkpoint)
+        model = AutoModelForCausalLM.from_pretrained(primary_checkpoint).to(device)
+        return tokenizer, model, primary_checkpoint
+    except Exception:
+        # Fallback path on Spaces
+        tokenizer = AutoTokenizer.from_pretrained(fallback_checkpoint)
+        model = AutoModelForCausalLM.from_pretrained(fallback_checkpoint).to(device)
+        return tokenizer, model, fallback_checkpoint
+
+tokenizer, model, model_name = load_model()
+
+# Load hotel-specific documents from disk as (document_id, content) pairs
+def load_hotel_docs(hotel_id: str):
+    path = os.path.join("knowledge", f"{hotel_id}.txt")
+    if not os.path.exists(path):
+        return []
+    with open(path, "r", encoding="utf-8") as f:
+        content = f.read().strip()
+    # Use a single document per hotel; the document_id is derived from hotel_id
+    return [(f"{hotel_id}-info", content)]
+
+# Chat function: plain Qwen chat locally, IBM Granite RAG with document roles on Spaces
+def chat(message, history, hotel_id):
+    if history is None:
+        history = []
+    # Append the user message in the Chatbot's "messages" (role/content dict) format
+    history.append({"role": "user", "content": message})
+
+    # ==== Local development flow: simple chat via Qwen ====
+    if not is_space:
+        # History is already a list of {"role": ..., "content": ...} dicts
+        msgs = history
+        # Apply Qwen's chat template
+        input_text = tokenizer.apply_chat_template(
+            msgs,
+            tokenize=False,
+            add_generation_prompt=True
+        )
+        print("printing templated chat (pre-tokenized), ready for sending to the model\n")
+        print(input_text)
+
+        # Generate the response
+        inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+        outputs = model.generate(
+            inputs,
+            max_new_tokens=1024,
+            do_sample=False
+        )
+        decoded = tokenizer.decode(outputs[0], skip_special_tokens=False)
+        print("and here is the de-tokenized response from the model\n")
+        print(decoded)
+        #response = decoded.split("<|assistant|>")[-1].strip()
+        response = decoded.split("<|im_start|>assistant\n")[-1].split("<|im_end|>")[0]
+        history.append({"role": "assistant", "content": f"{response}\n_(Model: {model_name})_"})
+        # Return the updated history and an empty string to clear the textbox
+        return history, ""
+
+    # ==== Space production flow: IBM Granite RAG ====
+    # Prepare the system prompt
+    system_prompt = (
+        "Knowledge Cutoff Date: April 2024. Today's Date: April 12, 2025. "
+        "You are Granite, developed by IBM. Write the response to the user's input "
+        "by strictly aligning with the facts in the provided documents. If the "
+        "information needed to answer the question is not available in the documents, "
+        "inform the user that the question cannot be answered based on the available data."
+    )
+    # Start building the message list
+    messages = [{"role": "system", "content": system_prompt}]
+    # Inject each document with role 'document' and metadata
+    for doc_id, doc_content in load_hotel_docs(hotel_id):
+        messages.append({
+            "role": "document",
+            "content": doc_content,
+            "document_id": doc_id
+        })
+    # Finally, add the user turn
+    messages.append({"role": "user", "content": message})
+
+    # Apply the model's chat template (IBM-trained template)
+    input_text = tokenizer.apply_chat_template(
+        messages,
+        tokenize=False,
+        add_generation_prompt=True
+    )
+
+    print("printing templated chat (pre-tokenized), ready for sending to the model\n")
+    print(input_text)
+
+    # Tokenize, generate, and decode
+    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
+    outputs = model.generate(
+        inputs,
+        max_new_tokens=1024,
+        do_sample=False
+    )
+    # Decode only the newly generated tokens; with skip_special_tokens=True the
+    # '<|assistant|>' marker is stripped from the output, so splitting on it cannot work
+    decoded = tokenizer.decode(outputs[0][inputs.shape[-1]:], skip_special_tokens=True)
+    response = decoded.strip()
+    history.append({"role": "assistant", "content": f"{response}\n_(Model: {model_name})_"})
+    # Return the updated history and an empty string to clear the textbox
+    return history, ""
+
+# Available hotels
+hotel_ids = [
+    "cyprus-guesthouse-family",
+    "coastal-villa-family",
+    "village-inn-family"
+]
+
+# Gradio interface setup
+demo = gr.Blocks()
+with demo:
+    gr.Markdown("### 🏨 Hotel Chatbot Demo")
+    with gr.Row():
+        hotel_selector = gr.Dropdown(hotel_ids, label="Choose a hotel", value=hotel_ids[0])
+    # Use the "messages" format so the role/content dicts built in chat() render correctly
+    chatbot = gr.Chatbot(type="messages")
+    msg = gr.Textbox(placeholder="Ask me about the hotel...", show_label=False)
+    msg.submit(
+        fn=chat,
+        inputs=[msg, chatbot, hotel_selector],
+        outputs=[chatbot, msg]
+    )
+    gr.Markdown("⚠️ **Reminder:** Pause the Space when done to avoid GPU charges.")

 if __name__ == "__main__":
+    demo.launch()
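
Note: `load_hotel_docs` returns an empty list when `knowledge/<hotel_id>.txt` is missing, so the Granite flow will reply that it cannot answer from the available data. To exercise the RAG path locally, seed one file per hotel id; a minimal sketch (the hotel facts below are illustrative, not from the repo):

    import os

    os.makedirs("knowledge", exist_ok=True)
    with open(os.path.join("knowledge", "cyprus-guesthouse-family.txt"), "w", encoding="utf-8") as f:
        # Illustrative content only; replace with the real hotel details
        f.write("Check-in from 14:00, check-out by 11:00.\n")
        f.write("Amenities: pool, free parking, breakfast included.\n")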
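For comparison, recent transformers releases also accept a `documents=` argument on `apply_chat_template` for RAG-aware chat templates such as Granite's, instead of hand-building `role: "document"` turns as the diff above does. A sketch under the assumption that the installed Granite template supports it (verify the rendered prompt before relying on it):

    # Sketch only: let the chat template inject the documents itself
    hotel_id, message = "cyprus-guesthouse-family", "Do you have a pool?"  # sample inputs
    docs = [
        {"title": doc_id, "text": doc_content}
        for doc_id, doc_content in load_hotel_docs(hotel_id)
    ]
    input_text = tokenizer.apply_chat_template(
        [{"role": "user", "content": message}],
        documents=docs,
        tokenize=False,
        add_generation_prompt=True,
    )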