keeperballon committed on
Commit
80bd43d
·
verified ·
1 Parent(s): 8f82122

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -414
app.py CHANGED
@@ -9,79 +9,10 @@ APP_DESCRIPTION = "Access and chat with multiple language models without requiri
9
 
10
  # Load environment variables
11
  ACCESS_TOKEN = os.getenv("HF_TOKEN")
12
- print("Access token loaded.")
13
-
14
  client = OpenAI(
15
  base_url="https://api-inference.huggingface.co/v1/",
16
  api_key=ACCESS_TOKEN,
17
  )
18
- print("OpenAI client initialized.")
19
-
20
-
21
- def respond(
22
- message,
23
- history,
24
- system_message,
25
- max_tokens,
26
- temperature,
27
- top_p,
28
- frequency_penalty,
29
- seed,
30
- custom_model
31
- ):
32
- print(f"Received message: {message}")
33
- print(f"Selected model: {custom_model}")
34
-
35
- # Convert seed to None if -1 (meaning random)
36
- if seed == -1:
37
- seed = None
38
-
39
- messages = [{"role": "system", "content": system_message}]
40
-
41
- # Add conversation history to the context
42
- for val in history:
43
- user_part = val[0]
44
- assistant_part = val[1]
45
- if user_part:
46
- messages.append({"role": "user", "content": user_part})
47
- if assistant_part:
48
- messages.append({"role": "assistant", "content": assistant_part})
49
-
50
- # Append the latest user message
51
- messages.append({"role": "user", "content": message})
52
-
53
- # If user provided a model, use that; otherwise, fall back to a default model
54
- model_to_use = custom_model.strip() if custom_model.strip() != "" else "meta-llama/Llama-3.3-70B-Instruct"
55
-
56
- # Create a copy of the history and add the new user message
57
- new_history = list(history)
58
- new_history.append((message, ""))
59
- current_response = ""
60
-
61
- try:
62
- for message_chunk in client.chat.completions.create(
63
- model=model_to_use,
64
- max_tokens=max_tokens,
65
- stream=True,
66
- temperature=temperature,
67
- top_p=top_p,
68
- frequency_penalty=frequency_penalty,
69
- seed=seed,
70
- messages=messages,
71
- ):
72
- token_text = message_chunk.choices[0].delta.content
73
- if token_text is not None: # Handle None type in response
74
- current_response += token_text
75
- # Update just the last message in history
76
- new_history[-1] = (message, current_response)
77
- yield new_history
78
- except Exception as e:
79
- error_message = f"Error: {str(e)}\n\nPlease check your model selection and parameters, or try again later."
80
- new_history[-1] = (message, error_message)
81
- yield new_history
82
-
83
- print("Completed response generation.")
84
-
85
 
86
  # Model categories for better organization
87
  MODEL_CATEGORIES = {
@@ -125,370 +56,124 @@ MODEL_CATEGORIES = {
125
  ]
126
  }
127
 
128
- # Flatten the model list for search functionality
129
- ALL_MODELS = []
130
- for category, models in MODEL_CATEGORIES.items():
131
- ALL_MODELS.extend(models)
132
-
133
 
134
  def get_model_info(model_name):
135
- """Extract and format model information for display"""
136
  parts = model_name.split('/')
137
  if len(parts) != 2:
138
  return f"**Model:** {model_name}\n**Format:** Unknown"
139
-
140
- org = parts[0]
141
- model = parts[1]
142
-
143
- # Extract numbers from model name to determine size
144
  import re
145
  size_match = re.search(r'(\d+\.?\d*)B', model)
146
  size = size_match.group(1) + "B" if size_match else "Unknown"
147
-
148
  return f"**Organization:** {org}\n**Model:** {model}\n**Size:** {size}"
149
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
 
151
- def set_model_and_update_info(model_name):
152
- """Set the selected model and update the model info display"""
153
- # This function is called when a model is selected (either clicked or via API)
 
154
  try:
155
- # Get model info
156
- model_info = get_model_info(model_name)
157
-
158
- # Return both the model name and the model info
159
- return model_name, model_info
 
 
 
 
 
 
 
 
 
 
160
  except Exception as e:
161
- print(f"Error in set_model_and_update_info: {e}")
162
- return model_name, f"**Error loading model info**: {str(e)}"
163
-
164
-
165
- def filter_models(search_term):
166
- """Filter models based on search term across all categories"""
167
- if not search_term:
168
- return MODEL_CATEGORIES
169
-
170
- filtered_categories = {}
171
- for category, models in MODEL_CATEGORIES.items():
172
- filtered_models = [m for m in models if search_term.lower() in m.lower()]
173
- if filtered_models:
174
- filtered_categories[category] = filtered_models
175
-
176
- return filtered_categories
177
-
178
-
179
- def update_model_display(search_term=""):
180
- """Update the model selection UI based on search term"""
181
- filtered_categories = filter_models(search_term)
182
-
183
- # Create HTML for model display with a more direct approach
184
- html = """
185
- <div style='max-height: 400px; overflow-y: auto;'>
186
- <script>
187
- // Direct model selection function - more reliable
188
- function selectModel(modelName) {
189
- // Get the textbox element by its ID
190
- const modelInput = document.getElementById('custom-model-input');
191
- if (modelInput) {
192
- // Set the value
193
- modelInput.value = modelName;
194
-
195
- // Create and dispatch change event
196
- const event = new Event('change', { bubbles: true });
197
- modelInput.dispatchEvent(event);
198
-
199
- // Look for the hidden trigger button and click it
200
- const triggerBtn = document.querySelector('button[value="Select Model"]');
201
- if (triggerBtn) {
202
- triggerBtn.click();
203
- }
204
-
205
- console.log('Selected model:', modelName);
206
- } else {
207
- console.error('Model input element not found');
208
- }
209
- }
210
- </script>
211
- """
212
-
213
- # Add models by category
214
- for category, models in filtered_categories.items():
215
- html += f"<h3>{category}</h3><div style='display: grid; grid-template-columns: repeat(auto-fill, minmax(250px, 1fr)); gap: 10px;'>"
216
-
217
- for model in models:
218
- model_short = model.split('/')[-1]
219
- escaped_model = model.replace("'", "\\'").replace('"', '\\"')
220
- html += f"""
221
- <div class='model-card'
222
- style='border: 1px solid #ddd; border-radius: 8px; padding: 12px; cursor: pointer; transition: all 0.2s;
223
- background: linear-gradient(145deg, #f0f0f0, #ffffff); box-shadow: 0 4px 6px rgba(0,0,0,0.1);'
224
- onclick="selectModel('{escaped_model}')">
225
- <div style='font-weight: bold; margin-bottom: 6px; color: #1a73e8;'>{model_short}</div>
226
- <div style='font-size: 0.8em; color: #666;'>{model.split('/')[0]}</div>
227
- </div>
228
- """
229
- html += "</div>"
230
-
231
- if not filtered_categories:
232
- html += "<p>No models found matching your search.</p>"
233
-
234
- html += "</div>"
235
- return html
236
-
237
-
238
- # Create custom CSS for better styling
239
- custom_css = """
240
- #app-container {
241
- max-width: 1200px;
242
- margin: 0 auto;
243
- padding: 20px;
244
- }
245
-
246
- #chat-container {
247
- border-radius: 12px;
248
- box-shadow: 0 8px 16px rgba(0,0,0,0.1);
249
- overflow: hidden;
250
- border: 1px solid #e0e0e0;
251
- }
252
-
253
- .contain {
254
- background: linear-gradient(135deg, #f5f7fa 0%, #e4e7eb 100%);
255
- }
256
-
257
- h1, h2, h3 {
258
- font-family: 'Poppins', sans-serif;
259
- }
260
-
261
- h1 {
262
- background: linear-gradient(90deg, #2b6cb0, #4299e1);
263
- -webkit-background-clip: text;
264
- -webkit-text-fill-color: transparent;
265
- font-weight: 700;
266
- letter-spacing: -0.5px;
267
- margin-bottom: 8px;
268
- }
269
-
270
- .parameter-row {
271
- display: flex;
272
- gap: 10px;
273
- margin-bottom: 10px;
274
- }
275
-
276
- .model-card:hover {
277
- transform: translateY(-2px);
278
- box-shadow: 0 6px 12px rgba(0,0,0,0.15);
279
- border-color: #4299e1;
280
- }
281
-
282
- .footer {
283
- text-align: center;
284
- margin-top: 20px;
285
- font-size: 0.8em;
286
- color: #666;
287
- }
288
-
289
- /* Status indicator styles */
290
- .status-indicator {
291
- display: inline-block;
292
- width: 10px;
293
- height: 10px;
294
- border-radius: 50%;
295
- margin-right: 6px;
296
- }
297
-
298
- .status-active {
299
- background-color: #10B981;
300
- animation: pulse 2s infinite;
301
- }
302
 
303
- @keyframes pulse {
304
- 0% {
305
- box-shadow: 0 0 0 0 rgba(16, 185, 129, 0.7);
306
- }
307
- 70% {
308
- box-shadow: 0 0 0 5px rgba(16, 185, 129, 0);
309
- }
310
- 100% {
311
- box-shadow: 0 0 0 0 rgba(16, 185, 129, 0);
312
- }
313
- }
314
- """
315
 
316
- with gr.Blocks(css=custom_css, title=APP_TITLE, theme=gr.themes.Soft()) as demo:
317
- gr.HTML(f"""
318
- <div id="app-container">
319
- <div style="text-align: center; padding: 20px 0;">
320
- <h1 style="font-size: 2.5rem;">{APP_TITLE}</h1>
321
- <p style="font-size: 1.1rem; color: #555;">{APP_DESCRIPTION}</p>
322
- <div style="margin-top: 10px;">
323
- <span class="status-indicator status-active"></span>
324
- <span>Service Active</span>
325
- <span style="margin-left: 15px;">Last Updated: {datetime.now().strftime('%Y-%m-%d')}</span>
326
- </div>
327
- </div>
328
- </div>
329
- """)
330
-
331
  with gr.Row():
332
  with gr.Column(scale=2):
333
- # Model selection panel - MOVED TO THE LEFT SIDE
334
- gr.HTML("<div style='border: 1px solid #e0e0e0; border-radius: 10px; padding: 15px;'>")
335
- gr.HTML("<h3 style='margin-top: 0;'>Model Selection</h3>")
336
-
337
- # Custom model input (this is what the respond function sees)
338
- custom_model_box = gr.Textbox(
339
- value="Qwen/Qwen3-32B", # Changed default model to Qwen
340
- label="Selected Model",
341
- elem_id="custom-model-input"
 
 
 
 
 
342
  )
343
-
344
- # Search box
345
- model_search_box = gr.Textbox(
346
- label="Search Models",
347
- placeholder="Type to filter models...",
348
- lines=1
349
  )
350
-
351
- # Dynamic model display area
352
- model_display = gr.HTML(update_model_display())
353
-
354
- # Model information display
355
- gr.HTML("<h4>Current Model Info</h4>")
356
- model_info_display = gr.Markdown(get_model_info("Qwen/Qwen3-32B"))
357
- gr.HTML("</div>")
358
-
359
  with gr.Column(scale=3):
360
- # Main chat interface
361
- chatbot = gr.Chatbot(
362
- height=550,
363
- show_copy_button=True,
364
- placeholder="Select a model and begin chatting",
365
- layout="panel",
366
- elem_id="chat-container"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
  )
368
-
369
- with gr.Row():
370
- with gr.Column(scale=8):
371
- msg = gr.Textbox(
372
- show_label=False,
373
- placeholder="Type your message here...",
374
- container=False,
375
- scale=8
376
- )
377
- with gr.Column(scale=1, min_width=70):
378
- submit_btn = gr.Button("Send", variant="primary", scale=1)
379
-
380
- with gr.Accordion("Conversation Settings", open=False):
381
- system_message_box = gr.Textbox(
382
- value="You are a helpful assistant.",
383
- placeholder="System prompt that guides the assistant's behavior",
384
- label="System Prompt",
385
- lines=2
386
- )
387
-
388
- # Use standard Row/Column layout instead of tabs that might not be available
389
- gr.HTML("<h3>Basic Parameters</h3>")
390
- with gr.Row():
391
- with gr.Column():
392
- max_tokens_slider = gr.Slider(
393
- minimum=1,
394
- maximum=4096,
395
- value=512,
396
- step=1,
397
- label="Max new tokens"
398
- )
399
- with gr.Column():
400
- temperature_slider = gr.Slider(
401
- minimum=0.1,
402
- maximum=4.0,
403
- value=0.7,
404
- step=0.1,
405
- label="Temperature"
406
- )
407
-
408
- gr.HTML("<h3>Advanced Parameters</h3>")
409
- with gr.Row():
410
- with gr.Column():
411
- top_p_slider = gr.Slider(
412
- minimum=0.1,
413
- maximum=1.0,
414
- value=0.95,
415
- step=0.05,
416
- label="Top-P"
417
- )
418
- with gr.Column():
419
- frequency_penalty_slider = gr.Slider(
420
- minimum=-2.0,
421
- maximum=2.0,
422
- value=0.0,
423
- step=0.1,
424
- label="Frequency Penalty"
425
- )
426
-
427
- seed_slider = gr.Slider(
428
- minimum=-1,
429
- maximum=65535,
430
- value=-1,
431
- step=1,
432
- label="Seed (-1 for random)"
433
- )
434
-
435
- # Footer
436
- gr.HTML("""
437
- <div class="footer">
438
- <p>Created with Gradio • Powered by Hugging Face Inference API</p>
439
- <p>This interface allows you to chat with various language models without requiring a GPU</p>
440
- </div>
441
- """)
442
-
443
- # Add a hidden button to trigger model selection via JavaScript
444
- trigger_model_selection = gr.Button("Select Model", visible=False)
445
-
446
- # Set up event handlers
447
- msg.submit(
448
- fn=respond,
449
- inputs=[msg, chatbot, system_message_box, max_tokens_slider, temperature_slider,
450
- top_p_slider, frequency_penalty_slider, seed_slider, custom_model_box],
451
- outputs=chatbot,
452
- queue=True
453
- ).then(
454
- lambda: "", # Clear the message box after sending
455
- None,
456
- [msg]
457
- )
458
-
459
- submit_btn.click(
460
- fn=respond,
461
- inputs=[msg, chatbot, system_message_box, max_tokens_slider, temperature_slider,
462
- top_p_slider, frequency_penalty_slider, seed_slider, custom_model_box],
463
- outputs=chatbot,
464
- queue=True
465
- ).then(
466
- lambda: "", # Clear the message box after sending
467
- None,
468
- [msg]
469
- )
470
-
471
- # Update model display when search changes
472
- model_search_box.change(
473
- fn=lambda x: update_model_display(x),
474
- inputs=model_search_box,
475
- outputs=model_display
476
- )
477
-
478
- # Update model info when selection changes
479
- custom_model_box.change(
480
- fn=set_model_and_update_info,
481
- inputs=custom_model_box,
482
- outputs=[custom_model_box, model_info_display]
483
- )
484
-
485
- # Connect the hidden trigger button to update model info
486
- trigger_model_selection.click(
487
- fn=set_model_and_update_info,
488
- inputs=custom_model_box,
489
- outputs=[custom_model_box, model_info_display]
490
- )
491
 
492
- if __name__ == "__main__":
493
- print("Launching the enhanced multi-model chat interface.")
494
- demo.launch()
 
9
 
10
  # Load environment variables
11
  ACCESS_TOKEN = os.getenv("HF_TOKEN")
 
 
12
  client = OpenAI(
13
  base_url="https://api-inference.huggingface.co/v1/",
14
  api_key=ACCESS_TOKEN,
15
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
  # Model categories for better organization
18
  MODEL_CATEGORIES = {
 
56
  ]
57
  }
58
 
59
# Flatten the categorized model mapping into one list; it feeds the
# dropdown choices and supplies the default model.
ALL_MODELS = []
for _category_models in MODEL_CATEGORIES.values():
    ALL_MODELS.extend(_category_models)
 
 
 
61
 
62
def get_model_info(model_name):
    """Return a short Markdown summary for a Hugging Face model id.

    Expects ``org/model`` form; anything else yields a minimal
    "Unknown format" summary. The parameter size is extracted from a
    token like ``70B`` / ``1.5B`` in the model part of the id.
    """
    parts = model_name.split('/')
    if len(parts) != 2:
        return f"**Model:** {model_name}\n**Format:** Unknown"
    org, model = parts
    import re  # local import kept, matching the original block's style
    # IGNORECASE also recognizes lowercase ids such as "mistral-7b",
    # which the previous case-sensitive pattern reported as "Unknown".
    size_match = re.search(r'(\d+\.?\d*)B', model, re.IGNORECASE)
    size = size_match.group(1) + "B" if size_match else "Unknown"
    return f"**Organization:** {org}\n**Model:** {model}\n**Size:** {size}"
71
 
72
def respond(
    message,
    history,
    system_message,
    max_tokens,
    temperature,
    top_p,
    frequency_penalty,
    seed,
    selected_model
):
    """Stream a chat completion for ``message`` given ``history``.

    ``history`` is a list of ``(user, assistant)`` tuples. Yields
    successive copies of the updated history as tokens stream in.
    On any API failure a final history update carrying the error text
    is yielded instead of raising, so the UI stays responsive.
    """
    # -1 is the UI sentinel for "random"; send None so no fixed seed is used.
    if seed == -1:
        seed = None

    # Rebuild the OpenAI-style message list: system prompt, prior turns, new message.
    messages = [{"role": "system", "content": system_message}]
    for user_msg, assistant_msg in history:
        if user_msg:
            messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})
    messages.append({"role": "user", "content": message})

    # Fall back to the first known model when the selection is empty OR
    # whitespace-only. The previous `selected_model or ...` check let a
    # string of spaces through, producing an invalid-model API error.
    model_to_use = selected_model.strip() if selected_model and selected_model.strip() else ALL_MODELS[0]

    # Work on a copy so the caller's history list is never mutated.
    new_history = list(history) + [(message, "")]
    current_response = ""
    try:
        for chunk in client.chat.completions.create(
            model=model_to_use,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
            frequency_penalty=frequency_penalty,
            seed=seed,
            messages=messages,
        ):
            delta = chunk.choices[0].delta.content
            if delta:  # delta is None on role/stop chunks
                current_response += delta
                new_history[-1] = (message, current_response)
                yield new_history
    except Exception as e:
        # Surface the failure inside the chat instead of crashing the stream.
        err = f"Error: {e}"
        new_history[-1] = (message, err)
        yield new_history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
118
 
119
with gr.Blocks(title=APP_TITLE, theme=gr.themes.Soft()) as demo:
    # Header: title plus short description.
    gr.Markdown(f"## {APP_TITLE}\n\n{APP_DESCRIPTION}")

    with gr.Row():
        # Left column: model picker plus generation settings.
        with gr.Column(scale=2):
            selected_model = gr.Dropdown(
                choices=ALL_MODELS,
                value=ALL_MODELS[0],
                label="Select Model",
            )
            model_info = gr.Markdown(get_model_info(ALL_MODELS[0]))

            def update_info(model_name):
                # Thin wrapper so the handler carries a descriptive name.
                return get_model_info(model_name)

            # Refresh the info panel whenever a different model is picked.
            selected_model.change(
                fn=update_info,
                inputs=[selected_model],
                outputs=[model_info],
            )

            system_message = gr.Textbox(
                value="You are a helpful assistant.",
                label="System Prompt",
                lines=2,
            )

            # Decoding parameters, forwarded verbatim to `respond`.
            max_tokens = gr.Slider(1, 4096, value=512, label="Max New Tokens")
            temperature = gr.Slider(0.1, 4.0, value=0.7, step=0.1, label="Temperature")
            top_p = gr.Slider(0.1, 1.0, value=0.95, step=0.05, label="Top-P")
            freq_penalty = gr.Slider(-2.0, 2.0, value=0.0, step=0.1, label="Frequency Penalty")
            seed = gr.Slider(-1, 65535, value=-1, step=1, label="Seed (-1 random)")

        # Right column: the chat itself.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot()
            msg = gr.Textbox(placeholder="Type your message here...", show_label=False)
            send_btn = gr.Button("Send")

            # The button and pressing Enter share one streaming handler.
            _chat_inputs = [
                msg, chatbot, system_message,
                max_tokens, temperature, top_p,
                freq_penalty, seed, selected_model,
            ]
            send_btn.click(
                fn=respond,
                inputs=_chat_inputs,
                outputs=[chatbot],
                queue=True,
            )
            msg.submit(
                fn=respond,
                inputs=_chat_inputs,
                outputs=[chatbot],
                queue=True,
            )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
 
179
if __name__ == "__main__":
    # Guard restored (the previous revision had it) so that importing this
    # module — e.g. from tests or tooling — does not start the web server.
    demo.launch()