Tonic committed
Commit 008a02f · unverified · 1 parent: 76a9fd4

add sliders

Files changed (1): app.py (+33 -22)
app.py CHANGED
@@ -131,6 +131,7 @@ def plot_bbox(image, data, use_quad_boxes=False):
         plt.text(x1, y1, label, color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))
 
     ax.axis('off')
+
     return fig
 
 def draw_ocr_bboxes(image, prediction):
@@ -145,6 +146,7 @@ def draw_ocr_bboxes(image, prediction):
                   "{}".format(label),
                   align="right",
                   fill=color)
+
     return image
 
 def draw_bounding_boxes(image, quad_boxes, labels, color=(0, 255, 0), thickness=2):
@@ -161,12 +163,7 @@ def draw_bounding_boxes(image, quad_boxes, labels, color=(0, 255, 0), thickness=
 
 def process_image(image, task):
     prompt = TASK_PROMPTS[task]
-    # # Print the inputs for debugging
-    # print(f"\n--- Processing Task: {task} ---")
-    # print(f"Prompt: {prompt}")
     inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
-    # # Print the input tensors for debugging
-    # print(f"Model Input: {inputs}")
     generated_ids = model.generate(
         **inputs,
         max_new_tokens=1024,
@@ -175,24 +172,38 @@ def process_image(image, task):
     )
 
     generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
-    # # Print the raw generated output for debugging
-    # print(f"Raw Model Output: {generated_text}")
     parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
-    # # Print the parsed answer for debugging
-    # print(f"Parsed Answer: {parsed_answer}")
+
     return parsed_answer
 
 
-def main_process(image, task):
-    result = process_image(image, task)
+def main_process(image, task, top_k, top_p, repetition_penalty, num_beams, max_tokens):
+    prompt = TASK_PROMPTS[task]
+    inputs = processor(text=prompt, images=image, return_tensors="pt").to(device, torch_dtype)
+    generated_ids = model.generate(
+        **inputs,
+        max_new_tokens=max_tokens,
+        num_beams=num_beams,
+        do_sample=True,
+        top_k=top_k,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty
+    )
+
+    generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
+    parsed_answer = processor.post_process_generation(generated_text, task=prompt, image_size=(image.width, image.height))
+    return parsed_answer
 
+def process_and_update(image, task, top_k, top_p, repetition_penalty, num_beams, max_tokens):
+    if image is None:
+        return None, gr.update(visible=False), "Please upload an image first.", gr.update(visible=True)
+    result = main_process(image, task, top_k, top_p, repetition_penalty, num_beams, max_tokens)
+
     if task in IMAGE_TASKS:
         if task == "📸✍🏻OCR with Region":
             fig = plot_bbox(image, result.get('<OCR_WITH_REGION>', {}), use_quad_boxes=True)
             output_image = fig_to_pil(fig)
             text_output = result.get('<OCR_WITH_REGION>', {}).get('recognized_text', 'No text found')
-            # # Debugging: Print the recognized text
-            # print(f"Recognized Text: {text_output}")
             return output_image, gr.update(visible=True), text_output, gr.update(visible=False)
         else:
             fig = plot_bbox(image, result.get(TASK_PROMPTS[task], {}))
@@ -201,7 +212,6 @@ def main_process(image, task):
     else:
         return None, gr.update(visible=False), str(result), gr.update(visible=True)
 
-
 def reset_outputs():
     return None, gr.update(visible=False), None, gr.update(visible=True)
 
@@ -224,20 +234,21 @@ with gr.Blocks(title="Tonic's 🙏🏻PLeIAs/📸📈✍🏻Florence-PDF") as if
             image_input = gr.Image(type="pil", label="Input Image")
             task_dropdown = gr.Dropdown(list(TASK_PROMPTS.keys()), label="Task", value="✍🏻Caption")
             with gr.Row():
-                submit_button = gr.Button("Process")
-                reset_button = gr.Button("Reset")
+                submit_button = gr.Button("📸📈✍🏻Process")
+                reset_button = gr.Button("♻️Reset")
+            with gr.Accordion("🧪Advanced Settings", open=False):
+                top_k = gr.Slider(minimum=1, maximum=100, value=50, step=1, label="Top-k")
+                top_p = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.01, label="Top-p")
+                repetition_penalty = gr.Slider(minimum=1.0, maximum=2.0, value=1.0, step=0.01, label="Repetition Penalty")
+                num_beams = gr.Slider(minimum=1, maximum=6, value=3, step=1, label="Number of Beams")
+                max_tokens = gr.Slider(minimum=1, maximum=1024, value=1000, step=1, label="Max Tokens")
         with gr.Column(scale=1):
             output_image = gr.Image(label="🙏🏻PLeIAs/📸📈✍🏻Florence-PDF", visible=False)
             output_text = gr.Textbox(label="🙏🏻PLeIAs/📸📈✍🏻Florence-PDF", visible=True)
-
-    def process_and_update(image, task):
-        if image is None:
-            return None, gr.update(visible=False), "Please upload an image first.", gr.update(visible=True)
-        return main_process(image, task)
 
     submit_button.click(
         fn=process_and_update,
-        inputs=[image_input, task_dropdown],
+        inputs=[image_input, task_dropdown, top_k, top_p, repetition_penalty, num_beams, max_tokens],
         outputs=[output_image, output_image, output_text, output_text]
     )
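
For context, here is a minimal standalone sketch of what the new sliders control, mirroring the processor/model setup that app.py does earlier; the checkpoint id, image path, and prompt below are assumptions for illustration, not values from this commit:

    # Sketch only: shows how the slider defaults map onto transformers'
    # generate() kwargs. The checkpoint id and image path are assumptions.
    import torch
    from PIL import Image
    from transformers import AutoModelForCausalLM, AutoProcessor

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch_dtype = torch.float16 if device == "cuda" else torch.float32
    model_id = "PleIAs/Florence-PDF"  # assumed checkpoint name
    model = AutoModelForCausalLM.from_pretrained(
        model_id, torch_dtype=torch_dtype, trust_remote_code=True
    ).to(device)
    processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)

    image = Image.open("page.png")  # placeholder input image
    inputs = processor(text="<OCR>", images=image, return_tensors="pt").to(device, torch_dtype)
    generated_ids = model.generate(
        **inputs,
        max_new_tokens=1000,     # "Max Tokens" slider (1-1024, default 1000)
        num_beams=3,             # "Number of Beams" slider (1-6, default 3)
        do_sample=True,          # required for top-k/top-p to take effect
        top_k=50,                # "Top-k" slider (1-100, default 50)
        top_p=1.0,               # "Top-p" slider (0.0-1.0, default 1.0)
        repetition_penalty=1.0,  # "Repetition Penalty" slider (1.0-2.0, default 1.0)
    )
    text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    print(text)

One design note: with do_sample=True and num_beams > 1, transformers runs beam-search multinomial sampling, so the Top-k and Top-p sliders still shape each beam's token choices; setting Number of Beams to 1 falls back to plain sampling.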