shravan123321 commited on
Commit
1dce535
·
verified ·
1 Parent(s): 65f793d

Update MVPR.py

Browse files
Files changed (1) hide show
  1. MVPR.py +95 -0
MVPR.py CHANGED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ from PIL import Image
4
+ import cv2
5
+ from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
6
+ import gradio as gr
7
+
8
+ # Initialize the SegFormer model for segmentation
9
+ segformer_processor = SegformerImageProcessor.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
10
+ segformer_model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b0-finetuned-ade-512-512")
11
+
12
+ # Function to segment the person in the image
13
+ def segment_person(image_input):
14
+ # Convert input image (numpy array in RGB) to PIL Image
15
+ image = Image.fromarray(image_input).convert("RGB")
16
+ original_width, original_height = image.size
17
+
18
+ # Resize image to 512x512 for the model
19
+ model_input = image.resize((512, 512), Image.Resampling.LANCZOS)
20
+
21
+ # Prepare the image for SegFormer
22
+ inputs = segformer_processor(images=model_input, return_tensors="pt")
23
+
24
+ # Perform inference
25
+ with torch.no_grad():
26
+ outputs = segformer_model(**inputs)
27
+ logits = outputs.logits
28
+
29
+ # Upsample logits to 512x512
30
+ upsampled_logits = torch.nn.functional.interpolate(
31
+ logits, size=(512, 512), mode="bilinear", align_corners=False
32
+ )
33
+
34
+ # Get the predicted segmentation mask (person class = 12 in ADE20K dataset)
35
+ person_class_id = 12
36
+ predicted_mask = upsampled_logits.argmax(dim=1)[0] # Shape: (512, 512)
37
+ binary_mask = (predicted_mask == person_class_id).cpu().numpy() # Boolean mask
38
+
39
+ # Post-process the mask
40
+ mask_uint8 = (binary_mask * 255).astype(np.uint8)
41
+ kernel = np.ones((5, 5), np.uint8)
42
+ mask_cleaned = cv2.morphologyEx(mask_uint8, cv2.MORPH_CLOSE, kernel, iterations=2)
43
+ mask_cleaned = cv2.morphologyEx(mask_cleaned, cv2.MORPH_OPEN, kernel, iterations=2)
44
+ mask_smoothed = cv2.GaussianBlur(mask_cleaned, (7, 7), 0)
45
+ _, mask_final = cv2.threshold(mask_smoothed, 127, 255, cv2.THRESH_BINARY)
46
+
47
+ # Resize mask back to original dimensions
48
+ mask_pil = Image.fromarray(mask_final)
49
+ mask_resized = mask_pil.resize((original_width, original_height), Image.Resampling.LANCZOS)
50
+ mask_array = np.array(mask_resized) > 0 # Boolean mask
51
+
52
+ return mask_array
53
+
54
+ # Function to apply background blur
55
+ def blur_background(image_input, blur_strength):
56
+ # Ensure image is in numpy array format (RGB)
57
+ image_array = np.array(image_input)
58
+
59
+ # Segment the person
60
+ mask = segment_person(image_array)
61
+
62
+ # Apply Gaussian blur to the entire image
63
+ sigma = blur_strength
64
+ blurred_image = cv2.GaussianBlur(image_array, (0, 0), sigmaX=sigma, sigmaY=sigma)
65
+
66
+ # Composite the original foreground with the blurred background
67
+ mask_3d = mask[:, :, np.newaxis] # Add channel dimension for broadcasting
68
+ result = np.where(mask_3d, image_array, blurred_image).astype(np.uint8)
69
+
70
+ return result
71
+
72
+ # Gradio interface function
73
+ def gradio_interface(image, blur_strength):
74
+ if image is None:
75
+ raise ValueError("Please upload an image.")
76
+
77
+ # Process the image
78
+ output_image = blur_background(image, blur_strength)
79
+
80
+ return output_image
81
+
82
+ # Create the Gradio app
83
+ app = gr.Interface(
84
+ fn=gradio_interface,
85
+ inputs=[
86
+ gr.Image(type="numpy", label="Upload Image"),
87
+ gr.Slider(minimum=1, maximum=25, value=10, step=1, label="Blur Strength (Sigma)")
88
+ ],
89
+ outputs=gr.Image(type="numpy", label="Output Image"),
90
+ title="Person Segmentation and Background Blur",
91
+ description="Upload an image to segment the person and blur the background. Adjust the blur strength using the slider."
92
+ )
93
+
94
+ # Launch the app
95
+ app.launch()