nuojohnchen commited on
Commit
ac66727
·
verified ·
1 Parent(s): b1f5007

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -30
app.py CHANGED
@@ -224,6 +224,40 @@ def on_model_series_change(model_series):
224
  return gr.update(choices=APOLLO_MODELS[model_series], value=APOLLO_MODELS[model_series][0])
225
  return gr.update(choices=[], value=None)
226
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  # Create Gradio interface
228
  with gr.Blocks(css=css) as demo:
229
  # Title and description
@@ -269,9 +303,9 @@ with gr.Blocks(css=css) as demo:
269
  label="Maximum Tokens"
270
  )
271
 
272
- # Load model button
273
- load_button = gr.Button("Load Model")
274
- model_status = gr.Textbox(label="Model Status", value="No model loaded yet")
275
 
276
  with gr.Column(scale=2):
277
  # Chat interface
@@ -292,40 +326,16 @@ with gr.Blocks(css=css) as demo:
292
  outputs=model_name
293
  )
294
 
295
- # Load model
296
- load_button.click(
297
- fn=load_model,
298
- inputs=model_name,
299
- outputs=model_status
300
- )
301
-
302
- # Bind message submission
303
- def process_message(message, chat_history):
304
- """Process user message and generate response"""
305
- if message.strip() == "":
306
- return "", chat_history
307
-
308
- # Add user message to chat history
309
- chat_history = list(chat_history)
310
- chat_history.append((message, None))
311
-
312
- # Generate response
313
- response = generate_response_non_streaming(message, model_name.value, temperature.value, max_tokens.value)
314
-
315
- # Add response to chat history
316
- chat_history[-1] = (message, response)
317
-
318
- return "", chat_history
319
-
320
  submit_event = user_input.submit(
321
  fn=process_message,
322
- inputs=[user_input, chatbot],
323
  outputs=[user_input, chatbot]
324
  )
325
 
326
  submit_button.click(
327
  fn=process_message,
328
- inputs=[user_input, chatbot],
329
  outputs=[user_input, chatbot]
330
  )
331
 
 
224
  return gr.update(choices=APOLLO_MODELS[model_series], value=APOLLO_MODELS[model_series][0])
225
  return gr.update(choices=[], value=None)
226
 
227
+ def process_message(message, chat_history, model_series_value, model_name_value, temperature_value, max_tokens_value):
228
+ """Process user message and generate response"""
229
+ if message.strip() == "":
230
+ return "", chat_history
231
+
232
+ # 打印用户提交的消息,用于调试
233
+ print("instruction:", message)
234
+
235
+ # Add user message to chat history
236
+ chat_history = list(chat_history)
237
+ chat_history.append((message, None))
238
+
239
+ # 自动加载模型(如果需要)
240
+ global current_model, current_tokenizer, current_model_path
241
+ if current_model_path != model_name_value or current_model is None:
242
+ try:
243
+ load_result = load_model(model_name_value)
244
+ if "failed" in load_result.lower():
245
+ chat_history[-1] = (message, f"模型加载失败: {load_result}")
246
+ return "", chat_history
247
+ except Exception as e:
248
+ chat_history[-1] = (message, f"模型加载出错: {str(e)}")
249
+ return "", chat_history
250
+
251
+ # Generate response
252
+ try:
253
+ response = generate_response_non_streaming(message, model_name_value, temperature_value, max_tokens_value)
254
+ # Add response to chat history
255
+ chat_history[-1] = (message, response)
256
+ except Exception as e:
257
+ chat_history[-1] = (message, f"生成响应时出错: {str(e)}")
258
+
259
+ return "", chat_history
260
+
261
  # Create Gradio interface
262
  with gr.Blocks(css=css) as demo:
263
  # Title and description
 
303
  label="Maximum Tokens"
304
  )
305
 
306
+ # 移除Load Model按钮和状态显示
307
+ # load_button = gr.Button("Load Model")
308
+ # model_status = gr.Textbox(label="Model Status", value="No model loaded yet")
309
 
310
  with gr.Column(scale=2):
311
  # Chat interface
 
326
  outputs=model_name
327
  )
328
 
329
+ # 修改提交事件绑定
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
330
  submit_event = user_input.submit(
331
  fn=process_message,
332
+ inputs=[user_input, chatbot, model_series, model_name, temperature, max_tokens],
333
  outputs=[user_input, chatbot]
334
  )
335
 
336
  submit_button.click(
337
  fn=process_message,
338
+ inputs=[user_input, chatbot, model_series, model_name, temperature, max_tokens],
339
  outputs=[user_input, chatbot]
340
  )
341