WillHeld commited on
Commit
b9a25dd
·
verified ·
1 Parent(s): dff15d7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -282
app.py CHANGED
@@ -1,112 +1,25 @@
1
  import spaces
2
- import os
3
- import uuid
4
- import time
5
- import json
6
- import torch
7
- from datetime import datetime, timedelta
8
- from threading import Thread
9
- from pathlib import Path
10
-
11
- # Gradio and HuggingFace imports
12
- import gradio as gr
13
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
14
- from datasets import Dataset
15
- from huggingface_hub import HfApi, login
 
 
16
 
17
- # Model configuration
18
  checkpoint = "WillHeld/soft-raccoon"
19
-
20
- # Set device based on availability
21
- if torch.cuda.is_available():
22
- device = "cuda"
23
- else:
24
- device = "cpu"
25
- print("CUDA not available, using CPU instead. This will be much slower.")
26
-
27
- # Dataset configuration
28
- DATASET_NAME = "WillHeld/soft-raccoon-conversations" # Change to your username
29
- SAVE_INTERVAL_MINUTES = 5 # Save data every 5 minutes
30
- last_save_time = datetime.now()
31
-
32
- # Initialize model and tokenizer
33
- print(f"Loading model from {checkpoint}...")
34
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
35
  model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
36
 
37
- # Data storage
38
- conversations = []
39
-
40
- # Hugging Face authentication
41
- # Uncomment this line to login with your token
42
- # login(token=os.environ.get("HF_TOKEN"))
43
-
44
- def save_to_dataset():
45
- """Save the current conversations to a HuggingFace dataset"""
46
- if not conversations:
47
- return None, f"No conversations to save. Last attempt: {datetime.now().strftime('%H:%M:%S')}"
48
-
49
- # Convert conversations to dataset format
50
- dataset_dict = {
51
- "conversation_id": [],
52
- "timestamp": [],
53
- "messages": [],
54
- "metadata": []
55
- }
56
-
57
- for conv in conversations:
58
- dataset_dict["conversation_id"].append(conv["conversation_id"])
59
- dataset_dict["timestamp"].append(conv["timestamp"])
60
- dataset_dict["messages"].append(json.dumps(conv["messages"]))
61
- dataset_dict["metadata"].append(json.dumps(conv["metadata"]))
62
-
63
- # Create dataset
64
- dataset = Dataset.from_dict(dataset_dict)
65
-
66
- try:
67
- # Push to hub
68
- dataset.push_to_hub(DATASET_NAME)
69
- status_msg = f"Successfully saved {len(conversations)} conversations to {DATASET_NAME}"
70
- print(status_msg)
71
- except Exception as e:
72
- # Save locally as fallback
73
- local_path = f"local_dataset_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
74
- dataset.save_to_disk(local_path)
75
- status_msg = f"Error pushing to hub: {str(e)}. Saved locally to '{local_path}'"
76
- print(status_msg)
77
-
78
- return dataset, status_msg
79
-
80
  @spaces.GPU(duration=120)
81
- def chat_model(message, history, temperature=0.7, top_p=0.9):
82
- """Chat function for use with ChatInterface"""
83
- conversation_id = getattr(chat_model, "conversation_id", None)
84
- if conversation_id is None:
85
- conversation_id = str(uuid.uuid4())
86
- chat_model.conversation_id = conversation_id
87
-
88
- # Format chat history for the model
89
- formatted_history = []
90
- for h in history:
91
- formatted_history.append({"role": "user", "content": h["content"] if h["role"] == "user" else ""})
92
- if h["role"] == "assistant":
93
- formatted_history.append({"role": "assistant", "content": h["content"]})
94
-
95
- # Add the current message
96
- formatted_history.append({"role": "user", "content": message})
97
-
98
- # Prepare input for the model
99
- input_text = tokenizer.apply_chat_template(
100
- formatted_history,
101
- tokenize=False,
102
- add_generation_prompt=True
103
- )
104
  inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
105
 
106
- # Set up streaming
107
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
108
 
109
- # Generation parameters
110
  generation_kwargs = {
111
  "input_ids": inputs,
112
  "max_new_tokens": 1024,
@@ -114,210 +27,73 @@ def chat_model(message, history, temperature=0.7, top_p=0.9):
114
  "top_p": float(top_p),
115
  "do_sample": True,
116
  "streamer": streamer,
 
117
  }
118
 
119
- # Generate in a separate thread
120
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
121
  thread.start()
122
 
123
- # Initialize response
124
  partial_text = ""
125
-
126
- # Yield partial text as it's generated
127
  for new_text in streamer:
128
  partial_text += new_text
129
  yield partial_text
130
-
131
- # Store conversation data in the global conversations list
132
- formatted_history.append({"role": "assistant", "content": partial_text})
133
-
134
- # Find existing conversation or create new one
135
- existing_conv = next((c for c in conversations if c["conversation_id"] == conversation_id), None)
136
-
137
- # Update or create conversation record
138
- current_time = datetime.now().isoformat()
139
- if existing_conv:
140
- # Update existing conversation
141
- existing_conv["messages"] = formatted_history
142
- existing_conv["metadata"]["last_updated"] = current_time
143
- existing_conv["metadata"]["temperature"] = temperature
144
- existing_conv["metadata"]["top_p"] = top_p
145
- else:
146
- # Create new conversation record
147
- conversations.append({
148
- "conversation_id": conversation_id,
149
- "timestamp": current_time,
150
- "messages": formatted_history,
151
- "metadata": {
152
- "model": checkpoint,
153
- "temperature": temperature,
154
- "top_p": top_p,
155
- "last_updated": current_time
156
- }
157
- })
158
-
159
- # Check if it's time to save based on elapsed time
160
- global last_save_time
161
- current_time_dt = datetime.now()
162
- if current_time_dt - last_save_time > timedelta(minutes=SAVE_INTERVAL_MINUTES):
163
- save_to_dataset()
164
- last_save_time = current_time_dt
165
-
166
- def save_dataset_manually():
167
- """Manually trigger dataset save and return status"""
168
- _, status = save_to_dataset()
169
- return status
170
-
171
- def get_stats():
172
- """Get current stats about conversations and saving"""
173
- mins_until_save = SAVE_INTERVAL_MINUTES - (datetime.now() - last_save_time).seconds // 60
174
- if mins_until_save < 0:
175
- mins_until_save = 0
176
-
177
- return {
178
- "conversation_count": len(conversations),
179
- "next_save": f"In {mins_until_save} minutes",
180
- "last_save": last_save_time.strftime('%H:%M:%S'),
181
- "dataset_name": DATASET_NAME
182
- }
183
-
184
- def update_stats_event():
185
- """Update the stats UI elements"""
186
- stats = get_stats()
187
- return [
188
- stats["conversation_count"],
189
- stats["next_save"],
190
- stats["last_save"],
191
- stats["dataset_name"]
192
- ]
193
 
194
- # Create a Stanford theme
195
- theme = gr.themes.Default(
196
- primary_hue=gr.themes.utils.colors.red,
197
- secondary_hue=gr.themes.utils.colors.gray,
198
- neutral_hue=gr.themes.utils.colors.gray,
199
- font=[gr.themes.GoogleFont("Source Sans Pro"), "ui-sans-serif", "system-ui"]
200
- ).set(
201
- button_primary_background_fill="#8C1515",
202
- button_primary_background_fill_hover="#771212",
203
- button_primary_text_color="white",
204
- slider_color="#8C1515",
205
- block_title_text_color="#8C1515",
206
- block_label_text_color="#4D4F53",
207
- input_border_color_focus="#8C1515",
208
- checkbox_background_color_selected="#8C1515",
209
- checkbox_border_color_selected="#8C1515",
210
- button_secondary_border_color="#4D4F53",
211
- block_title_background_fill="#f5f5f5",
212
- block_label_background_fill="#f9f9f9"
213
- )
214
 
215
- # Custom CSS
216
- css = """
217
- .gradio-container {
218
- font-family: 'Source Sans Pro', sans-serif !important;
219
- }
220
- .footer {
221
- color: #4D4F53 !important;
222
- font-size: 0.85em !important;
223
- }
224
- """
225
-
226
- # Set up the Gradio app with Blocks for more control
227
- with gr.Blocks(theme=theme, title="Stanford Soft Raccoon Chat", css=css) as demo:
228
- # Create a timer for periodic updates
229
- timer = gr.Timer(30, update_stats_event) # Update every 30 seconds
230
-
231
  with gr.Row():
232
  with gr.Column(scale=3):
233
- # Create the chatbot component directly
234
- chatbot = gr.Chatbot(
235
- label="Soft Raccoon Chat",
236
- avatar_images=(None, "🌲"), # Stanford tree emoji
237
- height=600,
238
- placeholder="<strong>Soft Raccoon AI Assistant</strong><br>Ask me anything!"
239
- )
240
-
241
- # Create sliders for temperature and top_p
242
- with gr.Accordion("Generation Parameters", open=False):
243
- temperature = gr.Slider(
244
- minimum=0.1,
245
- maximum=2.0,
246
- value=0.7,
247
- step=0.1,
248
- label="Temperature"
249
- )
250
- top_p = gr.Slider(
251
- minimum=0.1,
252
- maximum=1.0,
253
- value=0.9,
254
- step=0.05,
255
- label="Top-P"
256
- )
257
-
258
- # Create the ChatInterface
259
- chat_interface = gr.ChatInterface(
260
- fn=chat_model,
261
- chatbot=chatbot,
262
- additional_inputs=[temperature, top_p],
263
- type="messages", # This is important for compatibility
264
- title="Stanford Soft Raccoon Chat",
265
- description="AI assistant powered by the Soft Raccoon language model",
266
- examples=[
267
- ["Tell me about Stanford University", 0.7, 0.9],
268
- ["How can I learn about artificial intelligence?", 0.8, 0.95],
269
- ["What's your favorite book?", 0.6, 0.85]
270
  ],
271
- cache_examples=True,
272
  )
273
 
274
  with gr.Column(scale=1):
275
- with gr.Group():
276
- gr.Markdown("### Dataset Controls")
277
- save_button = gr.Button("Save conversations now", variant="secondary")
278
- status_output = gr.Textbox(label="Save Status", interactive=False)
279
-
280
- with gr.Row():
281
- convo_count = gr.Number(label="Total Conversations", interactive=False)
282
- next_save = gr.Textbox(label="Next Auto-Save", interactive=False)
283
-
284
- last_save_time_display = gr.Textbox(label="Last Save Time", interactive=False)
285
- dataset_name_display = gr.Textbox(label="Dataset Name", interactive=False)
286
-
287
- refresh_btn = gr.Button("Refresh Stats")
288
-
289
- # Set up event handlers
290
- save_button.click(
291
- save_dataset_manually,
292
- [],
293
- [status_output]
294
- )
295
-
296
- refresh_btn.click(
297
- update_stats_event,
298
- [],
299
- [convo_count, next_save, last_save_time_display, dataset_name_display]
300
- )
301
 
302
- # Connect the timer to update the stats displays
303
- timer.stream(
304
- fn=lambda: None, # No-op function
305
- inputs=None,
306
- outputs=None,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
307
  )
308
 
309
- # Initial update of stats on page load
310
- demo.load(
311
- update_stats_event,
312
- [],
313
- [convo_count, next_save, last_save_time_display, dataset_name_display]
314
  )
315
-
316
- # Ensure we save on shutdown
317
- import atexit
318
- atexit.register(save_to_dataset)
319
 
320
- # Launch the app
321
- if __name__ == "__main__":
322
- demo.queue() # Enable queuing for better performance
323
- demo.launch(share=True)
 
1
  import spaces
 
 
 
 
 
 
 
 
 
 
 
2
  from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
3
+ import gradio as gr
4
+ from threading import Thread
5
+ import os
6
+ from gradio_modal import Modal
7
 
 
8
  checkpoint = "WillHeld/soft-raccoon"
9
+ device = "cuda"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  tokenizer = AutoTokenizer.from_pretrained(checkpoint)
11
  model = AutoModelForCausalLM.from_pretrained(checkpoint).to(device)
12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  @spaces.GPU(duration=120)
14
+ def predict(message, history, temperature, top_p):
15
+ history.append({"role": "user", "content": message})
16
+ input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
18
 
19
+ # Create a streamer
20
  streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
21
 
22
+ # Set up generation parameters
23
  generation_kwargs = {
24
  "input_ids": inputs,
25
  "max_new_tokens": 1024,
 
27
  "top_p": float(top_p),
28
  "do_sample": True,
29
  "streamer": streamer,
30
+ "eos_token_id": 128009,
31
  }
32
 
33
+ # Run generation in a separate thread
34
  thread = Thread(target=model.generate, kwargs=generation_kwargs)
35
  thread.start()
36
 
37
+ # Yield from the streamer as tokens are generated
38
  partial_text = ""
 
 
39
  for new_text in streamer:
40
  partial_text += new_text
41
  yield partial_text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
+ # Function to handle the report submission
44
+ def submit_report(satisfaction, feedback_text):
45
+ # In a real application, you might save this to a database or file
46
+ print(f"Report submitted - Satisfaction: {satisfaction}, Feedback: {feedback_text}")
47
+ return "Thank you for your feedback! Your report has been submitted."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
+ with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  with gr.Row():
51
  with gr.Column(scale=3):
52
+ chatbot = gr.ChatInterface(
53
+ predict,
54
+ additional_inputs=[
55
+ gr.Slider(0.1, 2.0, value=0.7, step=0.1, label="Temperature"),
56
+ gr.Slider(0.1, 1.0, value=0.9, step=0.05, label="Top-P")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  ],
58
+ type="messages"
59
  )
60
 
61
  with gr.Column(scale=1):
62
+ report_button = gr.Button("File a Report", variant="primary")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
+ # Create the modal with feedback form components
65
+ with Modal(visible=False) as feedback_modal:
66
+ with gr.Column():
67
+ gr.Markdown("## We value your feedback!")
68
+ gr.Markdown("Please tell us about your experience with the model.")
69
+
70
+ satisfaction = gr.Radio(
71
+ ["Very satisfied", "Satisfied", "Neutral", "Unsatisfied", "Very unsatisfied"],
72
+ label="How satisfied are you with the model's responses?",
73
+ value="Neutral"
74
+ )
75
+
76
+ feedback_text = gr.Textbox(
77
+ lines=5,
78
+ label="Please provide any additional feedback or describe issues you encountered:",
79
+ placeholder="Enter your detailed feedback here..."
80
+ )
81
+
82
+ submit_button = gr.Button("Submit Feedback", variant="primary")
83
+ response_text = gr.Textbox(label="Status", interactive=False)
84
+
85
+ # Connect the "File a Report" button to show the modal
86
+ report_button.click(
87
+ lambda: Modal(visible=True),
88
+ None,
89
+ feedback_modal
90
  )
91
 
92
+ # Connect the submit button to the submit_report function
93
+ submit_button.click(
94
+ submit_report,
95
+ inputs=[satisfaction, feedback_text],
96
+ outputs=response_text
97
  )
 
 
 
 
98
 
99
+ demo.launch()