Update app.py
app.py CHANGED
@@ -1,112 +1,25 @@
 import spaces
-import os
-import uuid
-import time
-import json
-import torch
-from datetime import datetime, timedelta
-from threading import Thread
-from pathlib import Path
-
-# Gradio and HuggingFace imports
-import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
-
-from
+import gradio as gr
+from threading import Thread
+import os
+from gradio_modal import Modal
 
-# Model configuration
 checkpoint = "WillHeld/soft-raccoon"
-
-# Set device based on availability
-if torch.cuda.is_available():
-    device = "cuda"
-else:
-    device = "cpu"
-    print("CUDA not available, using CPU instead. This will be much slower.")
-
-# Dataset configuration
-DATASET_NAME = "WillHeld/soft-raccoon-conversations"  # Change to your username
-SAVE_INTERVAL_MINUTES = 5  # Save data every 5 minutes
-last_save_time = datetime.now()
-
-# Initialize model and tokenizer
-print(f"Loading model from {checkpoint}...")
+device = "cuda"
 tokenizer = AutoTokenizer.from_pretrained(checkpoint)
 model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
 
-# Data storage
-conversations = []
-
-# Hugging Face authentication
-# Uncomment this line to login with your token
-# login(token=os.environ.get("HF_TOKEN"))
-
-def save_to_dataset():
-    """Save the current conversations to a HuggingFace dataset"""
-    if not conversations:
-        return None, f"No conversations to save. Last attempt: {datetime.now().strftime('%H:%M:%S')}"
-
-    # Convert conversations to dataset format
-    dataset_dict = {
-        "conversation_id": [],
-        "timestamp": [],
-        "messages": [],
-        "metadata": []
-    }
-
-    for conv in conversations:
-        dataset_dict["conversation_id"].append(conv["conversation_id"])
-        dataset_dict["timestamp"].append(conv["timestamp"])
-        dataset_dict["messages"].append(json.dumps(conv["messages"]))
-        dataset_dict["metadata"].append(json.dumps(conv["metadata"]))
-
-    # Create dataset
-    dataset = Dataset.from_dict(dataset_dict)
-
-    try:
-        # Push to hub
-        dataset.push_to_hub(DATASET_NAME)
-        status_msg = f"Successfully saved {len(conversations)} conversations to {DATASET_NAME}"
-        print(status_msg)
-    except Exception as e:
-        # Save locally as fallback
-        local_path = f"local_dataset_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
-        dataset.save_to_disk(local_path)
-        status_msg = f"Error pushing to hub: {str(e)}. Saved locally to '{local_path}'"
-        print(status_msg)
-
-    return dataset, status_msg
-
 @spaces.GPU(duration=120)
-def chat_model(message, history, temperature=0.7, top_p=0.9):
-    """
-
-    if conversation_id is None:
-        conversation_id = str(uuid.uuid4())
-        chat_model.conversation_id = conversation_id
-
-    # Format chat history for the model
-    formatted_history = []
-    for h in history:
-        formatted_history.append({"role": "user", "content": h["content"] if h["role"] == "user" else ""})
-        if h["role"] == "assistant":
-            formatted_history.append({"role": "assistant", "content": h["content"]})
-
-    # Add the current message
-    formatted_history.append({"role": "user", "content": message})
-
-    # Prepare input for the model
-    input_text = tokenizer.apply_chat_template(
-        formatted_history,
-        tokenize=False,
-        add_generation_prompt=True
-    )
+def predict(message, history, temperature, top_p):
+    history.append({"role": "user", "content": message})
+    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
     inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
 
-    #
+    # Create a streamer
     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
 
-    #
+    # Set up generation parameters
     generation_kwargs = {
         "input_ids": inputs,
         "max_new_tokens": 1024,
@@ -114,210 +27,73 @@ def chat_model(message, history, temperature=0.7, top_p=0.9):
         "top_p": float(top_p),
         "do_sample": True,
         "streamer": streamer,
+        "eos_token_id": 128009,
     }
 
-    #
+    # Run generation in a separate thread
     thread = Thread(target=model.generate, kwargs=generation_kwargs)
     thread.start()
 
-    #
+    # Yield from the streamer as tokens are generated
     partial_text = ""
-
-    # Yield partial text as it's generated
     for new_text in streamer:
         partial_text += new_text
        yield partial_text
-
-    # Store conversation data in the global conversations list
-    formatted_history.append({"role": "assistant", "content": partial_text})
-
-    # Find existing conversation or create new one
-    existing_conv = next((c for c in conversations if c["conversation_id"] == conversation_id), None)
-
-    # Update or create conversation record
-    current_time = datetime.now().isoformat()
-    if existing_conv:
-        # Update existing conversation
-        existing_conv["messages"] = formatted_history
-        existing_conv["metadata"]["last_updated"] = current_time
-        existing_conv["metadata"]["temperature"] = temperature
-        existing_conv["metadata"]["top_p"] = top_p
-    else:
-        # Create new conversation record
-        conversations.append({
-            "conversation_id": conversation_id,
-            "timestamp": current_time,
-            "messages": formatted_history,
-            "metadata": {
-                "model": checkpoint,
-                "temperature": temperature,
-                "top_p": top_p,
-                "last_updated": current_time
-            }
-        })
-
-    # Check if it's time to save based on elapsed time
-    global last_save_time
-    current_time_dt = datetime.now()
-    if current_time_dt - last_save_time > timedelta(minutes=SAVE_INTERVAL_MINUTES):
-        save_to_dataset()
-        last_save_time = current_time_dt
-
-def save_dataset_manually():
-    """Manually trigger dataset save and return status"""
-    _, status = save_to_dataset()
-    return status
-
-def get_stats():
-    """Get current stats about conversations and saving"""
-    mins_until_save = SAVE_INTERVAL_MINUTES - (datetime.now() - last_save_time).seconds // 60
-    if mins_until_save < 0:
-        mins_until_save = 0
-
-    return {
-        "conversation_count": len(conversations),
-        "next_save": f"In {mins_until_save} minutes",
-        "last_save": last_save_time.strftime('%H:%M:%S'),
-        "dataset_name": DATASET_NAME
-    }
-
-def update_stats_event():
-    """Update the stats UI elements"""
-    stats = get_stats()
-    return [
-        stats["conversation_count"],
-        stats["next_save"],
-        stats["last_save"],
-        stats["dataset_name"]
-    ]
 
-#
-
-
-
-
-    font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui"]
-).set(
-    button_primary_background_fill="#8C1515",
-    button_primary_background_fill_hover="#771212",
-    button_primary_text_color="white",
-    slider_color="#8C1515",
-    block_title_text_color="#8C1515",
-    block_label_text_color="#4D4F53",
-    input_border_color_focus="#8C1515",
-    checkbox_background_color_selected="#8C1515",
-    checkbox_border_color_selected="#8C1515",
-    button_secondary_border_color="#4D4F53",
-    block_title_background_fill="#f5f5f5",
-    block_label_background_fill="#f9f9f9"
-)
+# Function to handle the report submission
+def submit_report(satisfaction, feedback_text):
+    # In a real application, you might save this to a database or file
+    print(f"Report submitted - Satisfaction: {satisfaction}, Feedback: {feedback_text}")
+    return "Thank you for your feedback! Your report has been submitted."
 
-
-css = """
-.gradio-container {
-    font-family: 'Source Sans Pro', sans-serif !important;
-}
-.footer {
-    color: #4D4F53 !important;
-    font-size: 0.85em !important;
-}
-"""
-
-# Set up the Gradio app with Blocks for more control
-with gr.Blocks(theme=theme, title="Stanford Soft Raccoon Chat", css=css) as demo:
-    # Create a timer for periodic updates
-    timer = gr.Timer(30, update_stats_event)  # Update every 30 seconds
-
+with gr.Blocks() as demo:
     with gr.Row():
         with gr.Column(scale=3):
-
-
-
-
-
-                placeholder="<strong>Soft Raccoon AI Assistant</strong><br>Ask me anything!"
-            )
-
-            # Create sliders for temperature and top_p
-            with gr.Accordion("Generation Parameters", open=False):
-                temperature = gr.Slider(
-                    minimum=0.1,
-                    maximum=2.0,
-                    value=0.7,
-                    step=0.1,
-                    label="Temperature"
-                )
-                top_p = gr.Slider(
-                    minimum=0.1,
-                    maximum=1.0,
-                    value=0.9,
-                    step=0.05,
-                    label="Top-P"
-                )
-
-            # Create the ChatInterface
-            chat_interface = gr.ChatInterface(
-                fn=chat_model,
-                chatbot=chatbot,
-                additional_inputs=[temperature, top_p],
-                type="messages",  # This is important for compatibility
-                title="Stanford Soft Raccoon Chat",
-                description="AI assistant powered by the Soft Raccoon language model",
-                examples=[
-                    ["Tell me about Stanford University", 0.7, 0.9],
-                    ["How can I learn about artificial intelligence?", 0.8, 0.95],
-                    ["What's your favorite book?", 0.6, 0.85]
             ],
-
+                type="messages"
         )
 
        with gr.Column(scale=1):
-
-            gr.Markdown("### Dataset Controls")
-            save_button = gr.Button("Save conversations now", variant="secondary")
-            status_output = gr.Textbox(label="Save Status", interactive=False)
-
-            with gr.Row():
-                convo_count = gr.Number(label="Total Conversations", interactive=False)
-                next_save = gr.Textbox(label="Next Auto-Save", interactive=False)
-
-            last_save_time_display = gr.Textbox(label="Last Save Time", interactive=False)
-            dataset_name_display = gr.Textbox(label="Dataset Name", interactive=False)
-
-            refresh_btn = gr.Button("Refresh Stats")
-
-    # Set up event handlers
-    save_button.click(
-        save_dataset_manually,
-        [],
-        [status_output]
-    )
-
-    refresh_btn.click(
-        update_stats_event,
-        [],
-        [convo_count, next_save, last_save_time_display, dataset_name_display]
-    )
+            report_button = gr.Button("File a Report", variant="primary")
 
-    #
-
-
-
-
+    # Create the modal with feedback form components
+    with Modal(visible=False) as feedback_modal:
+        with gr.Column():
+            gr.Markdown("## We value your feedback!")
+            gr.Markdown("Please tell us about your experience with the model.")
+
+            satisfaction = gr.Radio(
+                ["Very satisfied", "Satisfied", "Neutral", "Unsatisfied", "Very unsatisfied"],
+                label="How satisfied are you with the model's responses?",
+                value="Neutral"
+            )
+
+            feedback_text = gr.Textbox(
+                lines=5,
+                label="Please provide any additional feedback or describe issues you encountered:",
+                placeholder="Enter your detailed feedback here..."
+            )
+
+            submit_button = gr.Button("Submit Feedback", variant="primary")
+            response_text = gr.Textbox(label="Status", interactive=False)
+
+    # Connect the "File a Report" button to show the modal
+    report_button.click(
+        lambda: Modal(visible=True),
+        None,
+        feedback_modal
    )
 
-    #
-
-
-        [],
-
+    # Connect the submit button to the submit_report function
+    submit_button.click(
+        submit_report,
+        inputs=[satisfaction, feedback_text],
+        outputs=response_text
    )
-
-# Ensure we save on shutdown
-import atexit
-atexit.register(save_to_dataset)
 
-
-if __name__ == "__main__":
-    demo.queue()  # Enable queuing for better performance
-    demo.launch(share=True)
+demo.launch()