nickkun committed
Commit d547454 · verified · 1 Parent(s): efa322c

Update app.py


Added Model Selection

Files changed (1)
  1. app.py +105 -21
app.py CHANGED
@@ -10,11 +10,42 @@ import numpy as np
 import requests
 import cv2
 
-# Load models once
-print("Loading segmentation model...")
-segmentation_model = pipeline("image-segmentation", model="nvidia/segformer-b1-finetuned-cityscapes-1024-1024")
-print("Loading depth estimation model...")
-depth_estimator = pipeline("depth-estimation", model="Intel/zoedepth-nyu-kitti")
+# Dictionary of available segmentation models
+SEGMENTATION_MODELS = {
+    "NVIDIA SegFormer (B1, Cityscapes)": "nvidia/segformer-b1-finetuned-cityscapes-1024-1024",
+    "NVIDIA SegFormer (B0, ADE20K)": "nvidia/segformer-b0-finetuned-ade-512-512",
+    "Facebook MaskFormer (ADE20K)": "facebook/maskformer-swin-base-ade",
+    "OneFormer (COCO)": "shi-labs/oneformer_coco_swin_large",
+    "NVIDIA SegFormer (B5, Cityscapes)": "nvidia/segformer-b5-finetuned-cityscapes-1024-1024"
+}
+
+# Dictionary of available depth estimation models
+DEPTH_MODELS = {
+    "Intel ZoeDepth (NYU-KITTI)": "Intel/zoedepth-nyu-kitti",
+    "DPT (Large)": "Intel/dpt-large",
+    "DPT (Hybrid)": "Intel/dpt-hybrid-midas",
+    "GLPDepth": "vinvino02/glpn-nyu"
+}
+
+# Initialize model placeholders
+segmentation_model = None
+depth_estimator = None
+
+def load_segmentation_model(model_name):
+    """Load the selected segmentation model."""
+    global segmentation_model
+    model_path = SEGMENTATION_MODELS[model_name]
+    print(f"Loading segmentation model: {model_path}...")
+    segmentation_model = pipeline("image-segmentation", model=model_path)
+    return f"Loaded segmentation model: {model_name}"
+
+def load_depth_model(model_name):
+    """Load the selected depth estimation model."""
+    global depth_estimator
+    model_path = DEPTH_MODELS[model_name]
+    print(f"Loading depth estimation model: {model_path}...")
+    depth_estimator = pipeline("depth-estimation", model=model_path)
+    return f"Loaded depth model: {model_name}"
 
 def lens_blur(image, radius):
     """
 
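A note on the loader pattern added above: every click of a "Load" button rebuilds the pipeline from scratch, even for a model that was already loaded earlier in the session. A minimal sketch of a cached variant, not part of this commit (get_segmentation_pipeline is a hypothetical helper; SEGMENTATION_MODELS and the segmentation_model global are assumed to be in scope as defined in the diff):

from functools import lru_cache
from transformers import pipeline

@lru_cache(maxsize=4)  # hypothetical cache: keep a few pipelines resident; size to available RAM/VRAM
def get_segmentation_pipeline(model_path):
    # Built once per checkpoint; repeat calls return the cached pipeline.
    print(f"Loading segmentation model: {model_path}...")
    return pipeline("image-segmentation", model=model_path)

def load_segmentation_model(model_name):
    # Drop-in replacement for the loader in the diff, backed by the cache above.
    global segmentation_model
    segmentation_model = get_segmentation_pipeline(SEGMENTATION_MODELS[model_name])
    return f"Loaded segmentation model: {model_name}"

The same wrapper applies unchanged to the depth loader.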
@@ -72,6 +103,10 @@ def process_image(input_image, method, blur_intensity, blur_type):
       - output_image: final composited image.
       - mask_image: the mask used (binary for segmentation, normalized depth for depth-based).
     """
+    # Check if models are loaded
+    if segmentation_model is None or depth_estimator is None:
+        return input_image, input_image.convert("L")
+
     # Ensure image is in RGB mode
     input_image = input_image.convert("RGB")
 
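The guard added here fails silently: with no models loaded, the app returns the input image untouched plus a grayscale copy as the mask, and the user gets no explanation. A hedged alternative, assuming a Gradio version with gr.Error support (the helper name require_models is hypothetical):

import gradio as gr

def require_models():
    # Abort the event handler with a visible modal error instead of silently
    # echoing the unprocessed input back to the user.
    if segmentation_model is None or depth_estimator is None:
        raise gr.Error("Load a segmentation and a depth model on the Model Selection tab first.")

Calling require_models() at the top of process_image would make the failure visible in the UI.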
@@ -135,22 +170,71 @@ def process_image(input_image, method, blur_intensity, blur_type):
 with gr.Blocks() as demo:
     gr.Markdown("## Image Processing App: Segmentation & Depth-based Blur")
 
-    with gr.Row():
-        with gr.Column():
-            input_image = gr.Image(label="Input Image", type="pil")
-            method = gr.Radio(label="Processing Method",
-                              choices=["Segmented Background Blur", "Depth-based Variable Blur"],
-                              value="Segmented Background Blur")
-            blur_intensity = gr.Slider(label="Blur Intensity (Maximum Blur Radius)", minimum=1, maximum=30, step=1, value=15)
-            blur_type = gr.Dropdown(label="Blur Type", choices=["Gaussian Blur", "Lens Blur"], value="Gaussian Blur")
-            run_button = gr.Button("Process Image")
-        with gr.Column():
-            output_image = gr.Image(label="Output Image")
-            mask_output = gr.Image(label="Mask")
-
-    run_button.click(fn=process_image,
-                     inputs=[input_image, method, blur_intensity, blur_type],
-                     outputs=[output_image, mask_output])
+    with gr.Tab("Model Selection"):
+        with gr.Row():
+            with gr.Column():
+                seg_model_dropdown = gr.Dropdown(
+                    label="Segmentation Model",
+                    choices=list(SEGMENTATION_MODELS.keys()),
+                    value=list(SEGMENTATION_MODELS.keys())[0]
+                )
+                seg_model_load_btn = gr.Button("Load Segmentation Model")
+                seg_model_status = gr.Textbox(label="Status", value="No model loaded")
+
+            with gr.Column():
+                depth_model_dropdown = gr.Dropdown(
+                    label="Depth Estimation Model",
+                    choices=list(DEPTH_MODELS.keys()),
+                    value=list(DEPTH_MODELS.keys())[0]
+                )
+                depth_model_load_btn = gr.Button("Load Depth Model")
+                depth_model_status = gr.Textbox(label="Status", value="No model loaded")
+
+    with gr.Tab("Image Processing"):
+        with gr.Row():
+            with gr.Column():
+                input_image = gr.Image(label="Input Image", type="pil")
+                method = gr.Radio(label="Processing Method",
+                                  choices=["Segmented Background Blur", "Depth-based Variable Blur"],
+                                  value="Segmented Background Blur")
+                blur_intensity = gr.Slider(label="Blur Intensity (Maximum Blur Radius)",
+                                           minimum=1, maximum=30, step=1, value=15)
+                blur_type = gr.Dropdown(label="Blur Type",
+                                        choices=["Gaussian Blur", "Lens Blur"],
+                                        value="Gaussian Blur")
+                run_button = gr.Button("Process Image")
+            with gr.Column():
+                output_image = gr.Image(label="Output Image")
+                mask_output = gr.Image(label="Mask")
+
+    # Set up event handlers
+    seg_model_load_btn.click(
+        fn=load_segmentation_model,
+        inputs=[seg_model_dropdown],
+        outputs=[seg_model_status]
+    )
+
+    depth_model_load_btn.click(
+        fn=load_depth_model,
+        inputs=[depth_model_dropdown],
+        outputs=[depth_model_status]
+    )
+
+    run_button.click(
+        fn=process_image,
+        inputs=[input_image, method, blur_intensity, blur_type],
+        outputs=[output_image, mask_output]
+    )
+
+    # Load default models on startup
+    demo.load(
+        fn=lambda: (
+            load_segmentation_model(list(SEGMENTATION_MODELS.keys())[0]),
+            load_depth_model(list(DEPTH_MODELS.keys())[0])
+        ),
+        inputs=None,
+        outputs=[seg_model_status, depth_model_status]
+    )
 
 # Launch the app
 demo.launch()
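One behavior worth noting in the wiring above: demo.load fires on every page load, not once at server startup, so each visit re-runs both default loaders; since the loaders rebuild their pipelines on every call, the first render can be slow for each visitor (the caching sketch after the first hunk mitigates this). The tuple-return pattern itself is standard Gradio: a function returning an N-tuple fills N output components in order. A standalone sketch of just that pattern (component names are illustrative):

import gradio as gr

with gr.Blocks() as demo:
    status_a = gr.Textbox(label="Status A")
    status_b = gr.Textbox(label="Status B")
    # Fires on each page load; the returned 2-tuple fills the two outputs in order.
    demo.load(fn=lambda: ("A ready", "B ready"), inputs=None, outputs=[status_a, status_b])

demo.launch()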
 