Tonic committed on
Commit 25aa84d · 1 Parent(s): 522f297

Update app.py

Files changed (1)
app.py +98 -96
app.py CHANGED
@@ -16,7 +16,7 @@ BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"
 PUNCTUATION = "!?。"#$%&'()*+,-/:;<=>@[\]^_`{|}~⦅⦆「」、、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏."
 uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(Path(tempfile.gettempdir()) / "gradio")

-def _get_args():
+def _get_args() -> ArgumentParser:
     parser = ArgumentParser()
     parser.add_argument("-c", "--checkpoint-path", type=str, default=DEFAULT_CKPT_PATH,
                         help="Checkpoint name or path, default to %(default)r")
@@ -35,7 +35,7 @@ def _get_args():
     args = parser.parse_args()
     return args

-def handle_image_submission(_chatbot, task_history, file):
+def handle_image_submission(_chatbot, task_history, file) -> tuple:
     print("handle_image_submission called")
     if file is None:
         print("No file uploaded")
@@ -49,7 +49,7 @@ def handle_image_submission(_chatbot, task_history, file):
     return predict(_chatbot, task_history)


-def _load_model_tokenizer(args):
+def _load_model_tokenizer(args) -> tuple:
     model_id = args.checkpoint_path
     model_dir = snapshot_download(model_id, revision=args.revision)
     tokenizer = AutoTokenizer.from_pretrained(
@@ -75,7 +75,7 @@ def _load_model_tokenizer(args):
     return model, tokenizer


-def _parse_text(text):
+def _parse_text(text: str) -> str:
     lines = text.split("\n")
     lines = [line for line in lines if line != ""]
     count = 0
@@ -106,7 +106,7 @@ def _parse_text(text):
     text = "".join(lines)
     return text

-def save_image(image_file, upload_dir):
+def save_image(image_file, upload_dir: str) -> str:
     print("save_image called with:", image_file)
     Path(upload_dir).mkdir(parents=True, exist_ok=True)
     filename = secrets.token_hex(10) + Path(image_file.name).suffix
@@ -125,104 +125,106 @@ def add_file(history, task_history, file):
     task_history = task_history + [((file_path,), None)]
     return history, task_history

-def _launch_demo(args, model, tokenizer):
-    uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
-        Path(tempfile.gettempdir()) / "gradio"
-    )
-    def predict(_chatbot, task_history):
-        print("predict called")
-        if not _chatbot:
-            return _chatbot
-        chat_query = _chatbot[-1][0]
-        print("Chat query:", chat_query)
-
-        if isinstance(chat_query, tuple):
-            query = [{'image': chat_query[0]}]
-        else:
-            query = [{'text': _parse_text(chat_query)}]
-
-        print("Query for model:", query)
-        inputs = tokenizer.from_list_format(query)
-        tokenized_inputs = tokenizer(inputs, return_tensors='pt')
-        tokenized_inputs = tokenized_inputs.to(model.device)
-
-        pred = model.generate(**tokenized_inputs)
-        response = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
-        print("Model response:", response)
-        if 'image' in query[0]:
-            image = tokenizer.draw_bbox_on_latest_picture(response)
-            if image is not None:
-                image_path = save_image(image, uploaded_file_dir)
-                _chatbot[-1] = (chat_query, (image_path,))
-            else:
-                _chatbot[-1] = (chat_query, "No image to display.")
+
+def predict(_chatbot, task_history) -> list:
+    print("predict called")
+    if not _chatbot:
+        return _chatbot
+    chat_query = _chatbot[-1][0]
+    print("Chat query:", chat_query)
+
+    if isinstance(chat_query, tuple):
+        query = [{'image': chat_query[0]}]
+    else:
+        query = [{'text': _parse_text(chat_query)}]
+
+    print("Query for model:", query)
+    inputs = tokenizer.from_list_format(query)
+    tokenized_inputs = tokenizer(inputs, return_tensors='pt')
+    tokenized_inputs = tokenized_inputs.to(model.device)
+
+    pred = model.generate(**tokenized_inputs)
+    response = tokenizer.decode(pred.cpu()[0], skip_special_tokens=False)
+    print("Model response:", response)
+    if 'image' in query[0]:
+        image = tokenizer.draw_bbox_on_latest_picture(response)
+        if image is not None:
+            image_path = save_image(image, uploaded_file_dir)
+            _chatbot[-1] = (chat_query, (image_path,))
         else:
-            _chatbot[-1] = (chat_query, response)
+            _chatbot[-1] = (chat_query, "No image to display.")
+    else:
+        _chatbot[-1] = (chat_query, response)
+    return _chatbot
+
+def save_uploaded_image(image_file, upload_dir):
+    if image is None:
+        return None
+    temp_dir = secrets.token_hex(20)
+    temp_dir = Path(uploaded_file_dir) / temp_dir
+    temp_dir.mkdir(exist_ok=True, parents=True)
+    name = f"tmp{secrets.token_hex(5)}.jpg"
+    filename = temp_dir / name
+    image.save(str(filename))
+    return str(filename)
+
+def regenerate(_chatbot, task_history) -> list:
+    if not task_history:
         return _chatbot
+    item = task_history[-1]
+    if item[1] is None:
+        return _chatbot
+    task_history[-1] = (item[0], None)
+    chatbot_item = _chatbot.pop(-1)
+    if chatbot_item[0] is None:
+        _chatbot[-1] = (_chatbot[-1][0], None)
+    else:
+        _chatbot.append((chatbot_item[0], None))
+    return predict(_chatbot, task_history)

-    def save_uploaded_image(image_file, upload_dir):
-        if image is None:
-            return None
-        temp_dir = secrets.token_hex(20)
-        temp_dir = Path(uploaded_file_dir) / temp_dir
-        temp_dir.mkdir(exist_ok=True, parents=True)
-        name = f"tmp{secrets.token_hex(5)}.jpg"
-        filename = temp_dir / name
-        image.save(str(filename))
-        return str(filename)
-
-    def regenerate(_chatbot, task_history):
-        if not task_history:
-            return _chatbot
-        item = task_history[-1]
-        if item[1] is None:
-            return _chatbot
-        task_history[-1] = (item[0], None)
-        chatbot_item = _chatbot.pop(-1)
-        if chatbot_item[0] is None:
-            _chatbot[-1] = (_chatbot[-1][0], None)
-        else:
-            _chatbot.append((chatbot_item[0], None))
-        return predict(_chatbot, task_history)
-
-    def add_text(history, task_history, text):
-        task_text = text
-        if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
-            task_text = text[:-1]
-        history = history + [(_parse_text(text), None)]
-        task_history = task_history + [(task_text, None)]
-        return history, task_history, ""
-
-    def add_file(history, task_history, file):
-        if file is None:
-            return history, task_history  # Return if no file is uploaded
-        file_path = file.name
-        history = history + [((file.name,), None)]
-        task_history = task_history + [((file.name,), None)]
-        return history, task_history
+def add_text(history, task_history, text) -> tuple:
+    task_text = text
+    if len(text) >= 2 and text[-1] in PUNCTUATION and text[-2] not in PUNCTUATION:
+        task_text = text[:-1]
+    history = history + [(_parse_text(text), None)]
+    task_history = task_history + [(task_text, None)]
+    return history, task_history, ""

-    def reset_user_input():
-        return gr.update(value="")
+def add_file(history, task_history, file):
+    if file is None:
+        return history, task_history  # Return if no file is uploaded
+    file_path = file.name
+    history = history + [((file.name,), None)]
+    task_history = task_history + [((file.name,), None)]
+    return history, task_history
+
+def reset_user_input():
+    return gr.update(value="")

-    def process_response(response):
-        response = response.replace("<ref>", "").replace(r"</ref>", "")
-        response = re.sub(BOX_TAG_PATTERN, "", response)
-        return response
-    def process_history_for_model(task_history):
-        processed_history = []
-        for query, response in task_history:
-            if isinstance(query, tuple):
-                query = {'image': query[0]}
-            else:
-                query = {'text': query}
-            response = response or ""
-            processed_history.append((query, response))
-        return processed_history
+def process_response(response: str) -> str:
+    response = response.replace("<ref>", "").replace(r"</ref>", "")
+    response = re.sub(BOX_TAG_PATTERN, "", response)
+    return response
+def process_history_for_model(task_history) -> list:
+    processed_history = []
+    for query, response in task_history:
+        if isinstance(query, tuple):
+            query = {'image': query[0]}
+        else:
+            query = {'text': query}
+        response = response or ""
+        processed_history.append((query, response))
+    return processed_history
+
+def reset_state(task_history) -> list:
+    task_history.clear()
+    return []

-    def reset_state(task_history):
-        task_history.clear()
-        return []

+def _launch_demo(args, model, tokenizer):
+    uploaded_file_dir = os.environ.get("GRADIO_TEMP_DIR") or str(
+        Path(tempfile.gettempdir()) / "gradio"
+    )

     with gr.Blocks() as demo:
         gr.Markdown("""# Welcome to Tonic's Qwen-VL-Chat Bot""")
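Note on `save_uploaded_image`: both the removed and the added version test and save a name `image` that is not among the function's parameters, so the helper would most likely raise a NameError if it were ever called. A minimal corrected sketch, assuming the intent is to save the `image_file` argument (a PIL-style object with a `.save()` method) under the `upload_dir` that is passed in rather than the module-level `uploaded_file_dir`:

    import secrets
    from pathlib import Path

    def save_uploaded_image(image_file, upload_dir):
        # Hypothetical correction, not part of this commit: use the parameters
        # instead of the undefined name `image` and the module-level directory.
        if image_file is None:
            return None
        temp_dir = Path(upload_dir) / secrets.token_hex(20)
        temp_dir.mkdir(exist_ok=True, parents=True)
        filename = temp_dir / f"tmp{secrets.token_hex(5)}.jpg"
        image_file.save(str(filename))
        return str(filename)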
 
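After this commit, `predict`, `regenerate`, `add_text`, `add_file` and the other handlers are defined at module level but still read `model`, `tokenizer` and `uploaded_file_dir` as globals, while `_launch_demo(args, model, tokenizer)` keeps its original signature. The entry point is outside the hunks shown here; a minimal sketch of how the pieces are presumably wired together at the bottom of app.py (the module-level assignments are an assumption, only the three function names come from the diff):

    # Hypothetical wiring, not shown in this diff: load the model once so the
    # module-level handlers can see `model` and `tokenizer`, then launch the UI.
    args = _get_args()
    model, tokenizer = _load_model_tokenizer(args)

    if __name__ == "__main__":
        _launch_demo(args, model, tokenizer)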