drhead commited on
Commit
02a9646
·
verified ·
1 Parent(s): 2310671

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -18
app.py CHANGED
@@ -276,29 +276,18 @@ def create_cam_visualization_pil(cam, alpha=0.6, vis_threshold=0.2):
276
  # Normalize CAM to [0, 1]
277
  cam_norm = (cam_resized - cam_resized.min()) / (np.ptp(cam_resized) + 1e-8)
278
 
279
- # Apply threshold mask
280
- mask = cam_norm >= vis_threshold
281
-
282
  # Create heatmap using matplotlib colormap
283
  colormap = cm.get_cmap('jet')
284
- heatmap_rgba = colormap(cam_norm) # shape: (H, W, 4), values in [0, 1]
285
- heatmap_rgb = (heatmap_rgba[:, :, :3] * 255).astype(np.uint8)
286
-
287
- # Convert heatmap to PIL image
288
- heatmap_pil = Image.fromarray(heatmap_rgb).convert("RGB")
289
 
290
- # Convert images to NumPy for blending
291
- base_np = np.array(image_pil).astype(np.float32)
292
- heat_np = np.array(heatmap_pil).astype(np.float32)
293
 
294
- # Blend only where mask is True
295
- blended_np = base_np.copy()
296
- blended_np[mask] = base_np[mask] * (1 - alpha) + heat_np[mask] * alpha
297
- blended_np = np.clip(blended_np, 0, 255).astype(np.uint8)
298
 
299
- # Convert back to PIL image
300
- blended_img = Image.fromarray(blended_np)
301
- return blended_img
302
 
303
 
304
  with gr.Blocks(css=".output-class { display: none; }") as demo:
 
276
  # Normalize CAM to [0, 1]
277
  cam_norm = (cam_resized - cam_resized.min()) / (np.ptp(cam_resized) + 1e-8)
278
 
 
 
 
279
  # Create heatmap using matplotlib colormap
280
  colormap = cm.get_cmap('jet')
281
+ cam_colored = colormap(cam_norm)[:, :, :3] # RGB
282
+ cam_alpha = (cam_norm >= vis_threshold).astype(np.float32) * alpha # Alpha mask
 
 
 
283
 
284
+ cam_rgba = np.dstack((cam_colored, cam_alpha)) # Shape: (H, W, 4)
285
+ cam_image = Image.fromarray((cam_rgba * 255).astype(np.uint8), mode="RGBA")
 
286
 
287
+ # Composite over original
288
+ composite = Image.alpha_composite(image_pil, cam_image)
 
 
289
 
290
+ return composite
 
 
291
 
292
 
293
  with gr.Blocks(css=".output-class { display: none; }") as demo: