tsi-org committed on
Commit
d709286
·
1 Parent(s): e276da4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -41
app.py CHANGED
@@ -128,48 +128,47 @@ def flag_last_response(state, model_selector, request: gr.Request):
128
  return ("",) + (disable_btn,) * 3
129
 
130
 
131
- def regenerate(state, image_process_mode, request: gr.Request):
132
  logger.info(f"regenerate. ip: {request.client.host}")
133
  state.messages[-1][-1] = None
134
  prev_human_msg = state.messages[-2]
135
  if type(prev_human_msg[1]) in (tuple, list):
136
- prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode)
137
  state.skip_next = False
138
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
139
 
140
 
141
  def clear_history(request: gr.Request):
142
  logger.info(f"clear_history. ip: {request.client.host}")
143
  state = default_conversation.copy()
144
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
145
 
146
 
147
- def add_text(state, text, image, image_process_mode, request: gr.Request):
148
  logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
149
- if len(text) <= 0 and image is None:
150
  state.skip_next = True
151
- return (state, state.to_gradio_chatbot(), "", None) + (no_change_btn,) * 5
152
  if args.moderate:
153
  flagged = violates_moderation(text)
154
  if flagged:
155
  state.skip_next = True
156
- return (state, state.to_gradio_chatbot(), moderation_msg, None) + (
157
  no_change_btn,
158
  ) * 5
159
 
160
  text = text[:1536] # Hard cut-off
161
- if image is not None:
162
  text = text[:1200] # Hard cut-off for images
163
  if "<image>" not in text:
164
- # text = '<Image><image></Image>' + text
165
  text = text + "\n<image>"
166
- text = (text, image, image_process_mode)
167
  if len(state.get_images(return_pil=True)) > 0:
168
  state = default_conversation.copy()
169
  state.append_message(state.roles[0], text)
170
  state.append_message(state.roles[1], None)
171
  state.skip_next = False
172
- return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 5
173
 
174
 
175
  def http_bot(
@@ -180,12 +179,10 @@ def http_bot(
180
  model_name = model_selector
181
 
182
  if state.skip_next:
183
- # This generate call is skipped due to invalid inputs
184
  yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
185
  return
186
 
187
  if len(state.messages) == state.offset + 2:
188
- # First round of conversation
189
  if "llava" in model_name.lower():
190
  if "llama-2" in model_name.lower():
191
  template_name = "llava_llama_2"
@@ -222,7 +219,6 @@ def http_bot(
222
  new_state.append_message(new_state.roles[1], None)
223
  state = new_state
224
 
225
- # Query worker address
226
  controller_url = args.controller_url
227
  ret = requests.post(
228
  controller_url + "/get_worker_address", json={"model": model_name}
@@ -230,7 +226,6 @@ def http_bot(
230
  worker_addr = ret.json()["address"]
231
  logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
232
 
233
- # No available worker
234
  if worker_addr == "":
235
  state.messages[-1][-1] = server_error_msg
236
  yield (
@@ -244,7 +239,6 @@ def http_bot(
244
  )
245
  return
246
 
247
- # Construct prompt
248
  prompt = state.get_prompt()
249
 
250
  all_images = state.get_images(return_pil=True)
@@ -258,7 +252,6 @@ def http_bot(
258
  os.makedirs(os.path.dirname(filename), exist_ok=True)
259
  image.save(filename)
260
 
261
- # Make requests
262
  pload = {
263
  "model": model_name,
264
  "prompt": prompt,
@@ -278,7 +271,6 @@ def http_bot(
278
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
279
 
280
  try:
281
- # Stream output
282
  response = requests.post(
283
  worker_addr + "/worker_generate_stream",
284
  headers=headers,
@@ -339,17 +331,13 @@ def http_bot(
339
  title_markdown = """
340
  # ๐ŸŒ‹ AI Tutor Vision: Large Language and Vision Assistant
341
  [[website]](https://myapps.ai) [[Paper]](https://arxiv.org/abs/2304.08485) [[Model]](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)
342
-
343
  ONLY WORKS WITH GPU!
344
-
345
  You can load the model with 4-bit or 8-bit quantization to make it fit in smaller hardwares. Setting the environment variable `bits` to control the quantization.
346
  *Note: 8-bit seems to be slower than both 4-bit/16-bit. Although it has enough VRAM to support 8-bit, until we figure out the inference speed issue, we recommend 4-bit for A10G for the best efficiency.*
347
-
348
  Recommended configurations:
349
  | Hardware | T4-Small (16G) | A10G-Small (24G) | A100-Large (40G) |
350
  |-------------------|-----------------|------------------|------------------|
351
  | **Bits** | 4 (default) | 4 | 16 |
352
-
353
  """
354
 
355
  tos_markdown = """
@@ -367,11 +355,9 @@ The service is a research preview intended for non-commercial use only, subject
367
  """
368
 
369
  block_css = """
370
-
371
  #buttons button {
372
  min-width: min(120px,100%);
373
  }
374
-
375
  """
376
 
377
 
@@ -398,8 +384,15 @@ def build_demo(embed_mode):
398
  container=False,
399
  )
400
 
401
- imagebox = gr.Image(type="pil")
402
- image_process_mode = gr.Radio(
 
 
 
 
 
 
 
403
  ["Crop", "Resize", "Pad", "Default"],
404
  value="Default",
405
  label="Preprocess for non-square image",
@@ -418,7 +411,7 @@ def build_demo(embed_mode):
418
  "What are the things I should be cautious about when I visit here?",
419
  ],
420
  ],
421
- inputs=[imagebox, textbox],
422
  )
423
 
424
  with gr.Accordion("Parameters", open=False) as parameter_row:
@@ -456,13 +449,11 @@ def build_demo(embed_mode):
456
  textbox.render()
457
  with gr.Column(scale=1, min_width=50):
458
  submit_btn = gr.Button(
459
- value="Send", variant="primary", interactive=False
460
- )
461
  with gr.Row(elem_id="buttons") as button_row:
462
  upvote_btn = gr.Button(value="๐Ÿ‘ Upvote", interactive=False)
463
  downvote_btn = gr.Button(value="๐Ÿ‘Ž Downvote", interactive=False)
464
  flag_btn = gr.Button(value="โš ๏ธ Flag", interactive=False)
465
- # stop_btn = gr.Button(value="โน๏ธ Stop Generation", interactive=False)
466
  regenerate_btn = gr.Button(value="๐Ÿ”„ Regenerate", interactive=False)
467
  clear_btn = gr.Button(value="๐Ÿ—‘๏ธ Clear history", interactive=False)
468
 
@@ -471,7 +462,6 @@ def build_demo(embed_mode):
471
  gr.Markdown(learn_more_markdown)
472
  url_params = gr.JSON(visible=False)
473
 
474
- # Register listeners
475
  btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
476
  upvote_btn.click(
477
  upvote_last_response,
@@ -490,21 +480,21 @@ def build_demo(embed_mode):
490
  )
491
  regenerate_btn.click(
492
  regenerate,
493
- [state, image_process_mode],
494
- [state, chatbot, textbox, imagebox] + btn_list,
495
  ).then(
496
  http_bot,
497
  [state, model_selector, temperature, top_p, max_output_tokens],
498
  [state, chatbot] + btn_list,
499
  )
500
  clear_btn.click(
501
- clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
502
  )
503
 
504
  textbox.submit(
505
  add_text,
506
- [state, textbox, imagebox, image_process_mode],
507
- [state, chatbot, textbox, imagebox] + btn_list,
508
  ).then(
509
  http_bot,
510
  [state, model_selector, temperature, top_p, max_output_tokens],
@@ -512,8 +502,8 @@ def build_demo(embed_mode):
512
  )
513
  submit_btn.click(
514
  add_text,
515
- [state, textbox, imagebox, image_process_mode],
516
- [state, chatbot, textbox, imagebox] + btn_list,
517
  ).then(
518
  http_bot,
519
  [state, model_selector, temperature, top_p, max_output_tokens],
@@ -610,7 +600,6 @@ if __name__ == "__main__":
610
  controller_proc = start_controller()
611
  worker_proc = start_worker(model_path, bits=bits)
612
 
613
- # Wait for worker and controller to start
614
  time.sleep(10)
615
 
616
  exit_status = 0
@@ -623,4 +612,4 @@ if __name__ == "__main__":
623
  worker_proc.kill()
624
  controller_proc.kill()
625
 
626
- sys.exit(exit_status)
 
128
  return ("",) + (disable_btn,) * 3
129
 
130
 
131
+ def regenerate(state, image_process_mode1, image_process_mode2, request: gr.Request):
132
  logger.info(f"regenerate. ip: {request.client.host}")
133
  state.messages[-1][-1] = None
134
  prev_human_msg = state.messages[-2]
135
  if type(prev_human_msg[1]) in (tuple, list):
136
+ prev_human_msg[1] = (*prev_human_msg[1][:2], image_process_mode1, image_process_mode2)
137
  state.skip_next = False
138
+ return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
139
 
140
 
141
  def clear_history(request: gr.Request):
142
  logger.info(f"clear_history. ip: {request.client.host}")
143
  state = default_conversation.copy()
144
+ return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
145
 
146
 
147
+ def add_text(state, text, image1, image2, image_process_mode1, image_process_mode2, request: gr.Request):
148
  logger.info(f"add_text. ip: {request.client.host}. len: {len(text)}")
149
+ if len(text) <= 0 and image1 is None and image2 is None:
150
  state.skip_next = True
151
+ return (state, state.to_gradio_chatbot(), "", None, None) + (no_change_btn,) * 5
152
  if args.moderate:
153
  flagged = violates_moderation(text)
154
  if flagged:
155
  state.skip_next = True
156
+ return (state, state.to_gradio_chatbot(), moderation_msg, None, None) + (
157
  no_change_btn,
158
  ) * 5
159
 
160
  text = text[:1536] # Hard cut-off
161
+ if image1 is not None or image2 is not None:
162
  text = text[:1200] # Hard cut-off for images
163
  if "<image>" not in text:
 
164
  text = text + "\n<image>"
165
+ text = (text, image1, image2, image_process_mode1, image_process_mode2)
166
  if len(state.get_images(return_pil=True)) > 0:
167
  state = default_conversation.copy()
168
  state.append_message(state.roles[0], text)
169
  state.append_message(state.roles[1], None)
170
  state.skip_next = False
171
+ return (state, state.to_gradio_chatbot(), "", None, None) + (disable_btn,) * 5
172
 
173
 
174
  def http_bot(
 
179
  model_name = model_selector
180
 
181
  if state.skip_next:
 
182
  yield (state, state.to_gradio_chatbot()) + (no_change_btn,) * 5
183
  return
184
 
185
  if len(state.messages) == state.offset + 2:
 
186
  if "llava" in model_name.lower():
187
  if "llama-2" in model_name.lower():
188
  template_name = "llava_llama_2"
 
219
  new_state.append_message(new_state.roles[1], None)
220
  state = new_state
221
 
 
222
  controller_url = args.controller_url
223
  ret = requests.post(
224
  controller_url + "/get_worker_address", json={"model": model_name}
 
226
  worker_addr = ret.json()["address"]
227
  logger.info(f"model_name: {model_name}, worker_addr: {worker_addr}")
228
 
 
229
  if worker_addr == "":
230
  state.messages[-1][-1] = server_error_msg
231
  yield (
 
239
  )
240
  return
241
 
 
242
  prompt = state.get_prompt()
243
 
244
  all_images = state.get_images(return_pil=True)
 
252
  os.makedirs(os.path.dirname(filename), exist_ok=True)
253
  image.save(filename)
254
 
 
255
  pload = {
256
  "model": model_name,
257
  "prompt": prompt,
 
271
  yield (state, state.to_gradio_chatbot()) + (disable_btn,) * 5
272
 
273
  try:
 
274
  response = requests.post(
275
  worker_addr + "/worker_generate_stream",
276
  headers=headers,
 
331
  title_markdown = """
332
  # ๐ŸŒ‹ AI Tutor Vision: Large Language and Vision Assistant
333
  [[website]](https://myapps.ai) [[Paper]](https://arxiv.org/abs/2304.08485) [[Model]](https://github.com/haotian-liu/LLaVA/blob/main/docs/MODEL_ZOO.md)
 
334
  ONLY WORKS WITH GPU!
 
335
  You can load the model with 4-bit or 8-bit quantization to make it fit in smaller hardwares. Setting the environment variable `bits` to control the quantization.
336
  *Note: 8-bit seems to be slower than both 4-bit/16-bit. Although it has enough VRAM to support 8-bit, until we figure out the inference speed issue, we recommend 4-bit for A10G for the best efficiency.*
 
337
  Recommended configurations:
338
  | Hardware | T4-Small (16G) | A10G-Small (24G) | A100-Large (40G) |
339
  |-------------------|-----------------|------------------|------------------|
340
  | **Bits** | 4 (default) | 4 | 16 |
 
341
  """
342
 
343
  tos_markdown = """
 
355
  """
356
 
357
  block_css = """
 
358
  #buttons button {
359
  min-width: min(120px,100%);
360
  }
 
361
  """
362
 
363
 
 
384
  container=False,
385
  )
386
 
387
+ imagebox1 = gr.Image(type="pil")
388
+ imagebox2 = gr.Image(type="pil")
389
+ image_process_mode1 = gr.Radio(
390
+ ["Crop", "Resize", "Pad", "Default"],
391
+ value="Default",
392
+ label="Preprocess for non-square image",
393
+ visible=False,
394
+ )
395
+ image_process_mode2 = gr.Radio(
396
  ["Crop", "Resize", "Pad", "Default"],
397
  value="Default",
398
  label="Preprocess for non-square image",
 
411
  "What are the things I should be cautious about when I visit here?",
412
  ],
413
  ],
414
+ inputs=[imagebox1, textbox, imagebox2],
415
  )
416
 
417
  with gr.Accordion("Parameters", open=False) as parameter_row:
 
449
  textbox.render()
450
  with gr.Column(scale=1, min_width=50):
451
  submit_btn = gr.Button(
452
+ value="Send", variant="primary", interactive=False)
 
453
  with gr.Row(elem_id="buttons") as button_row:
454
  upvote_btn = gr.Button(value="๐Ÿ‘ Upvote", interactive=False)
455
  downvote_btn = gr.Button(value="๐Ÿ‘Ž Downvote", interactive=False)
456
  flag_btn = gr.Button(value="โš ๏ธ Flag", interactive=False)
 
457
  regenerate_btn = gr.Button(value="๐Ÿ”„ Regenerate", interactive=False)
458
  clear_btn = gr.Button(value="๐Ÿ—‘๏ธ Clear history", interactive=False)
459
 
 
462
  gr.Markdown(learn_more_markdown)
463
  url_params = gr.JSON(visible=False)
464
 
 
465
  btn_list = [upvote_btn, downvote_btn, flag_btn, regenerate_btn, clear_btn]
466
  upvote_btn.click(
467
  upvote_last_response,
 
480
  )
481
  regenerate_btn.click(
482
  regenerate,
483
+ [state, image_process_mode1, image_process_mode2],
484
+ [state, chatbot, textbox, imagebox1, imagebox2] + btn_list,
485
  ).then(
486
  http_bot,
487
  [state, model_selector, temperature, top_p, max_output_tokens],
488
  [state, chatbot] + btn_list,
489
  )
490
  clear_btn.click(
491
+ clear_history, None, [state, chatbot, textbox, imagebox1, imagebox2] + btn_list
492
  )
493
 
494
  textbox.submit(
495
  add_text,
496
+ [state, textbox, imagebox1, imagebox2, image_process_mode1, image_process_mode2],
497
+ [state, chatbot, textbox, imagebox1, imagebox2] + btn_list,
498
  ).then(
499
  http_bot,
500
  [state, model_selector, temperature, top_p, max_output_tokens],
 
502
  )
503
  submit_btn.click(
504
  add_text,
505
+ [state, textbox, imagebox1, imagebox2, image_process_mode1, image_process_mode2],
506
+ [state, chatbot, textbox, imagebox1, imagebox2] + btn_list,
507
  ).then(
508
  http_bot,
509
  [state, model_selector, temperature, top_p, max_output_tokens],
 
600
  controller_proc = start_controller()
601
  worker_proc = start_worker(model_path, bits=bits)
602
 
 
603
  time.sleep(10)
604
 
605
  exit_status = 0
 
612
  worker_proc.kill()
613
  controller_proc.kill()
614
 
615
+ sys.exit(exit_status)