drhead commited on
Commit
2d4900f
·
verified ·
1 Parent(s): 20d5713

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -4
app.py CHANGED
@@ -242,13 +242,16 @@ def create_cam_visualization_pil(image_pil, cam, alpha=0.6, vis_threshold=0.2):
242
  if cam is None:
243
  return image_pil
244
  w, h = image_pil.size
 
245
 
246
  # Resize CAM to match image
247
- cam_resized = np.array(Image.fromarray(cam).resize((w, h), resample=Image.Resampling.BILINEAR))
248
 
249
  # Normalize CAM to [0, 1]
250
  cam_norm = (cam_resized - cam_resized.min()) / (np.ptp(cam_resized) + 1e-8)
251
 
 
 
252
  # Create heatmap using matplotlib colormap
253
  colormap = cm.get_cmap('jet')
254
  cam_colored = colormap(cam_norm)[:, :, :3] # RGB
@@ -304,7 +307,7 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
304
  fn=create_tags,
305
  inputs=[threshold_slider, sorted_tag_score_state],
306
  outputs=[tag_string, label_box],
307
- show_progress='none'
308
  )
309
 
310
  label_box.select(
@@ -318,14 +321,14 @@ with gr.Blocks(css=".output-class { display: none; }") as demo:
318
  fn=create_cam_visualization_pil,
319
  inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
320
  outputs=[image_input],
321
- show_progress='none'
322
  )
323
 
324
  alpha_slider.input(
325
  fn=create_cam_visualization_pil,
326
  inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
327
  outputs=[image_input],
328
- show_progress='none'
329
  )
330
 
331
  if __name__ == "__main__":
 
242
  if cam is None:
243
  return image_pil
244
  w, h = image_pil.size
245
+ size = max(w, h)
246
 
247
  # Resize CAM to match image
248
+ cam_resized = np.array(Image.fromarray(cam).resize((size, size), resample=Image.Resampling.BILINEAR))
249
 
250
  # Normalize CAM to [0, 1]
251
  cam_norm = (cam_resized - cam_resized.min()) / (np.ptp(cam_resized) + 1e-8)
252
 
253
+ cam_norm = transforms.CenterCrop((h, w))(cam_norm)
254
+
255
  # Create heatmap using matplotlib colormap
256
  colormap = cm.get_cmap('jet')
257
  cam_colored = colormap(cam_norm)[:, :, :3] # RGB
 
307
  fn=create_tags,
308
  inputs=[threshold_slider, sorted_tag_score_state],
309
  outputs=[tag_string, label_box],
310
+ show_progress='hidden'
311
  )
312
 
313
  label_box.select(
 
321
  fn=create_cam_visualization_pil,
322
  inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
323
  outputs=[image_input],
324
+ show_progress='hidden'
325
  )
326
 
327
  alpha_slider.input(
328
  fn=create_cam_visualization_pil,
329
  inputs=[original_image_state, cam_state, alpha_slider, cam_slider],
330
  outputs=[image_input],
331
+ show_progress='hidden'
332
  )
333
 
334
  if __name__ == "__main__":