apexmin's picture
Fix depth map not shown properly
5ce954e
import gradio as gr
import torch
import numpy as np
import cv2
from PIL import Image
import matplotlib.pyplot as plt
from transformers import AutoImageProcessor, AutoModelForDepthEstimation
from io import BytesIO
# Load the preprocessor and the depth-estimation network once, at import time,
# from the same Hugging Face checkpoint so the two always stay in sync.
_DEPTH_CHECKPOINT = "depth-anything/Depth-Anything-V2-Small-hf"
image_processor = AutoImageProcessor.from_pretrained(_DEPTH_CHECKPOINT)
model = AutoModelForDepthEstimation.from_pretrained(_DEPTH_CHECKPOINT)
def _normalize_depth(depth):
    """Rescale a raw depth array to the [0, 1] range.

    A flat (or NaN) prediction has no usable range; dividing by it would fill
    the map with NaNs. In that case an all-ones map is returned, which falls
    outside every blur band and therefore leaves the image untouched.
    """
    lowest = depth.min()
    span = depth.max() - lowest
    if not span > 0:  # also true when span is NaN
        return np.ones_like(depth)
    return (depth - lowest) / span


def _blur_by_depth(image_np, normalized_depth, steps):
    """Return a copy of ``image_np`` blurred in ``steps`` depth bands.

    The range [0, 0.9] of ``normalized_depth`` is cut into ``steps`` equal
    bands. The band just below 0.9 gets a 1x1 kernel (no visible blur) and the
    kernel grows by 2 per band towards 0. Pixels above 0.9 are never modified.
    """
    result = image_np.copy()
    if steps < 1:
        return result

    interval = 0.9 / steps  # loop-invariant band width
    for i in range(steps):
        kernel_size = i * 2 + 1  # always odd and positive, as OpenCV requires
        print(f'sigma: {kernel_size}')
        closer = 0.9 - (i * interval)
        further = 0.9 - ((i + 1) * interval)
        mask = normalized_depth <= closer
        if i < steps - 1:
            mask &= normalized_depth > further
        # The last band has no lower bound: `further` is only approximately 0
        # there, and a strict `>` would skip pixels at exactly 0.
        print(f'closer: {closer}, further: {further}')
        if not mask.any():
            continue
        try:
            blurred = cv2.GaussianBlur(image_np, (kernel_size, kernel_size), 0)
        except cv2.error as e:
            # Best effort: skip this band but keep the rest of the result.
            print(f"Error applying blur with kernel size {kernel_size}: {e}")
            continue
        # A 2-D boolean mask selects whole pixels (all channels) at once.
        result[mask] = blurred[mask]
    return result


def process_image(image, total_degrade_steps=15):
    """Estimate depth for ``image`` and blur it progressively by depth band.

    Args:
        image: A PIL image, or an array accepted by ``Image.fromarray``.
        total_degrade_steps: Number of blur bands; coerced to ``int``.

    Returns:
        A ``(depth_image, result_image)`` tuple of 512x512 PIL images: the
        normalized depth map as 8-bit grayscale, and the RGB blurred image.

    Raises:
        gr.Error: If ``image`` is ``None`` (nothing was uploaded).

    Lower normalized depth values receive stronger blur.
    NOTE(review): this matches the UI's "farther = blurrier" claim only if the
    model output grows towards the camera (inverse-depth convention) - confirm
    against the model card.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Convert to PIL if needed
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Work in RGB throughout so the model input and the blurred array agree on
    # channel count (RGBA / grayscale inputs previously broke the masking).
    image = image.convert('RGB').resize((512, 512), Image.LANCZOS)
    steps = int(total_degrade_steps)

    # Run the depth model
    inputs = image_processor(images=image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    predicted_depth = outputs.predicted_depth

    # Interpolate the prediction back to the image size (PIL size is (W, H))
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    )
    print(f'total_degrade_steps {total_degrade_steps}')

    normalized_depth = _normalize_depth(
        prediction.squeeze().detach().cpu().numpy()
    )

    # Grayscale visualization of the depth map
    depth_image = Image.fromarray((normalized_depth * 255).astype(np.uint8))

    result = _blur_by_depth(np.array(image), normalized_depth, steps)
    result_image = Image.fromarray(result)
    print(f'result_image size {result_image.size}')
    return depth_image, result_image
# Build the Gradio interface: inputs on the left, depth map and blurred result
# on the right. Event wiring must happen inside the Blocks context.
with gr.Blocks(title="Depth-Based Blur Effect") as demo:
    gr.Markdown("# Depth-Based Blur Effect")
    gr.Markdown("This app applies variable Gaussian blur to images based on depth estimation. Objects farther from the camera appear more blurred, while closer objects remain sharper.")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            total_steps = gr.Slider(minimum=5, maximum=20, value=15, step=1, label="Total Blur Levels")
            submit_btn = gr.Button("Apply Depth-Based Blur")
        with gr.Column():
            depth_map = gr.Image(type="pil", label="Depth Map")
            # process_image returns only the blurred picture (no side-by-side
            # composite), so the label says exactly that.
            output_image = gr.Image(type="pil", label="Blurred Result")
    submit_btn.click(
        process_image,
        inputs=[input_image, total_steps],
        outputs=[depth_map, output_image]
    )
    gr.Examples(
        examples=[
            ["assets/sample.jpg"],
        ],
        inputs=input_image
    )
    # The markdown body stays at column 0: indented lines would be rendered
    # as a code block.
    gr.Markdown("""
## How it works
1. The app uses the Depth-Anything-V2-Small model to estimate depth in the image
2. Depth values are normalized to a range of 0-1
3. A variable Gaussian blur is applied based on depth values
4. Objects farther from the camera (lower values, darker in the depth map) receive stronger blur
5. Objects closer to the camera (higher values, brighter in the depth map) remain sharper

This creates a realistic depth-of-field effect similar to what's seen in photography.
""")
# Launch the app only when executed as a script, so importing this module
# (e.g. for reuse or testing) does not start a blocking web server.
if __name__ == "__main__":
    demo.launch()