Vishva007 commited on
Commit
280f1e8
·
verified ·
1 Parent(s): a9ba0ac

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +323 -0
app.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # streamlit_app.py
2
+
3
+ import streamlit as st
4
+ import numpy as np
5
+ from PIL import Image
6
+ import cv2
7
+ import torch
8
+ from transformers import pipeline
9
+ import time
10
+ import os
11
+ from io import BytesIO # <-- IMPORT BytesIO
12
+
13
+ # --- Page Config (MUST BE FIRST st command) ---
14
+ # Set page config early
15
+ st.set_page_config(
16
+ page_title="Depth Blur Studio",
17
+ page_icon="📸",
18
+ layout="wide"
19
+ )
20
+
21
+ # --- Import Custom Class ---
22
+ # Assuming PortraitBlurrer.py is in a subfolder 'Portrait' relative to this script
23
+ try:
24
+ # If PortraitBlurrer is in ./Portrait/Portrait.py
25
+ from Portrait.Portrait import PortraitBlurrer
26
+ except ImportError:
27
+ # Fallback if PortraitBlurrer is in ./PortraitBlurrer.py
28
+ try:
29
+ from PortraitBlurrer import PortraitBlurrer # type: ignore
30
+ # st.warning("Assuming PortraitBlurrer class is in the root directory.") # Optional warning
31
+ except ImportError:
32
+ st.error("Fatal Error: Could not find the PortraitBlurrer class. Please check the file structure and import path.")
33
+ st.stop() # Stop execution if class can't be found
34
+
35
+
36
+ # --- Model Loading (Cached) ---
37
+ @st.cache_resource # Use cache_resource for non-data objects like models/pipelines
38
+ def load_depth_pipeline():
39
+ """Loads the depth estimation pipeline and caches it. Returns tuple (pipeline, device_id)."""
40
+ t_device = 0 if torch.cuda.is_available() else -1
41
+ print(f"Attempting to load model on device: {'GPU (CUDA)' if t_device == 0 else 'CPU'}")
42
+ try:
43
+ # Use default precision (float32)
44
+ t_pipe = pipeline(task="depth-estimation",
45
+ model="depth-anything/Depth-Anything-V2-Large-hf",
46
+ device=t_device)
47
+ print("Depth Anything V2 Large model loaded successfully.")
48
+ return t_pipe, t_device # Return pipeline and device used
49
+ except Exception as e:
50
+ print(f"Error loading model: {e}")
51
+ # Error will be displayed in the main app body after this function returns None
52
+ return None, t_device # Return None for pipe on error
53
+
54
+ # Load the model via the cached function
55
+ pipe, device_used = load_depth_pipeline()
56
+
57
+ # --- Title and Model Status ---
58
+ # Display title and info AFTER attempting model load
59
+ st.title("Depth Blur Studio 📸 (Streamlit)")
60
+ st.markdown(
61
+ "Upload a portrait image. The model will estimate depth and blur the background, keeping the subject sharp."
62
+ "\n*Model: `depth-anything/Depth-Anything-V2-Large-hf`*"
63
+ )
64
+ st.caption(f"_(Using device: {'GPU (CUDA)' if device_used == 0 else 'CPU'})_") # Display device info
65
+
66
+ # Handle model loading failure AFTER potential UI elements like title
67
+ if pipe is None:
68
+ st.error("Error loading depth estimation model. Application cannot proceed.")
69
+ st.stop() # Stop if model loading failed
70
+
71
+
72
+ # --- Processing Function ---
73
+ def process_image_blur(pipeline_obj, input_image_pil, max_blur_ksize, depth_thresh, feather_ksize, sharpen_val):
74
+ """
75
+ Processes the image using the pipeline and PortraitBlurrer.
76
+ Returns tuple: (blurred_pil, depth_pil, mask_pil) or (None, None, None) on failure.
77
+ """
78
+ print("Processing image...")
79
+ processing_start_time = time.time()
80
+
81
+ # 1. Convert PIL Image (RGB) to NumPy array (BGR for OpenCV)
82
+ input_image_np_rgb = np.array(input_image_pil)
83
+ original_bgr_np = cv2.cvtColor(input_image_np_rgb, cv2.COLOR_RGB2BGR)
84
+
85
+ # 2. Perform depth estimation
86
+ try:
87
+ with torch.no_grad(): # Inference only
88
+ depth_output = pipeline_obj(input_image_pil)
89
+ # Ensure depth map is PIL Image
90
+ if isinstance(depth_output, dict) and "depth" in depth_output:
91
+ depth_image_pil = depth_output["depth"]
92
+ if not isinstance(depth_image_pil, Image.Image):
93
+ # Attempt conversion if it's tensor/numpy (specifics might depend on pipeline output)
94
+ # This is a basic attempt; might need refinement based on actual output type
95
+ try:
96
+ depth_data = np.array(depth_image_pil)
97
+ # Normalize if needed (example: scale to 0-255)
98
+ depth_data = cv2.normalize(depth_data, None, 0, 255, cv2.NORM_MINMAX, dtype=cv2.CV_8U)
99
+ depth_image_pil = Image.fromarray(depth_data)
100
+ except Exception as conversion_e:
101
+ print(f"Could not convert depth output to PIL Image: {conversion_e}")
102
+ raise ValueError("Depth estimation did not return a usable PIL Image.")
103
+ else:
104
+ # Handle cases where output might be directly the image or unexpected format
105
+ if isinstance(depth_output, Image.Image):
106
+ depth_image_pil = depth_output
107
+ else:
108
+ raise ValueError(f"Unexpected depth estimation output format: {type(depth_output)}")
109
+
110
+ print("Depth map generated.")
111
+ except Exception as e:
112
+ print(f"Error during depth estimation: {e}")
113
+ st.error(f"Depth estimation failed: {e}") # Show error in UI
114
+ return None, None, None
115
+
116
+ # 3. Initialize Blurrer and Process
117
+ portrait_blurrer = PortraitBlurrer(
118
+ max_blur=int(max_blur_ksize),
119
+ depth_threshold=int(depth_thresh),
120
+ feather_strength=int(feather_ksize),
121
+ sharpen_strength=float(sharpen_val) # Use the passed sharpen value
122
+ )
123
+
124
+ try:
125
+ # process_image returns blurred_bgr, depth_gray, mask_gray
126
+ blurred_bgr_np, refined_depth_np, mask_np = portrait_blurrer.process_image(
127
+ original_bgr_np, depth_image_pil
128
+ )
129
+ except Exception as e:
130
+ print(f"Error during blurring/sharpening: {e}")
131
+ st.error(f"Image processing (blur/sharpen) failed: {e}") # Show error in UI
132
+ return None, None, None
133
+
134
+ # 4. Convert results back to RGB PIL Images for Streamlit display
135
+ blurred_pil = Image.fromarray(cv2.cvtColor(blurred_bgr_np, cv2.COLOR_BGR2RGB))
136
+ # Depth and mask are grayscale numpy, convert directly to PIL
137
+ depth_pil = Image.fromarray(refined_depth_np)
138
+ mask_pil = Image.fromarray(mask_np)
139
+
140
+ processing_end_time = time.time()
141
+ processing_duration = processing_end_time - processing_start_time
142
+ print(f"Processing finished in {processing_duration:.2f} seconds.")
143
+ # Move success message display outside this function, near where results are shown
144
+ # st.success(f"Processing finished in {processing_duration:.2f} seconds.")
145
+
146
+ return blurred_pil, depth_pil, mask_pil, processing_duration # Return duration
147
+
148
+
149
+ # --- Initialize Session State --- (Do this early)
150
+ if 'results' not in st.session_state:
151
+ st.session_state.results = None # Will store tuple (blurred, depth, mask) or None
152
+ if 'original_image_pil' not in st.session_state:
153
+ st.session_state.original_image_pil = None
154
+ if 'processing_error_occurred' not in st.session_state:
155
+ st.session_state.processing_error_occurred = False
156
+ if 'current_filename' not in st.session_state:
157
+ st.session_state.current_filename = None
158
+ if 'last_process_duration' not in st.session_state:
159
+ st.session_state.last_process_duration = None
160
+
161
+
162
+ # --- Sidebar for Controls ---
163
+ with st.sidebar: # Use 'with' notation for clarity
164
+ st.title("Controls")
165
+ uploaded_file = st.file_uploader(
166
+ "Upload Portrait Image",
167
+ type=["jpg", "png", "jpeg"],
168
+ label_visibility="collapsed"
169
+ )
170
+
171
+ # --- Handle New Upload for Instant Display ---
172
+ if uploaded_file is not None:
173
+ # Check if it's a new file by comparing names
174
+ if uploaded_file.name != st.session_state.get('current_filename', None):
175
+ print(f"New file uploaded: {uploaded_file.name}. Loading for display.")
176
+ try:
177
+ # Load the new image immediately
178
+ st.session_state.original_image_pil = Image.open(uploaded_file).convert("RGB")
179
+ # Clear previous results, error state and duration
180
+ st.session_state.results = None
181
+ st.session_state.processing_error_occurred = False
182
+ st.session_state.last_process_duration = None
183
+ # Update the tracked filename
184
+ st.session_state.current_filename = uploaded_file.name
185
+ except Exception as e:
186
+ st.error(f"Error loading image: {e}")
187
+ # Clear states if loading failed
188
+ st.session_state.original_image_pil = None
189
+ st.session_state.results = None
190
+ st.session_state.processing_error_occurred = False
191
+ st.session_state.current_filename = None
192
+ st.session_state.last_process_duration = None
193
+
194
+ elif st.session_state.current_filename is not None:
195
+ # If file uploader is cleared by the user (uploaded_file becomes None)
196
+ print("File upload cleared.")
197
+ st.session_state.original_image_pil = None
198
+ st.session_state.results = None
199
+ st.session_state.processing_error_occurred = False
200
+ st.session_state.current_filename = None
201
+ st.session_state.last_process_duration = None
202
+ # --- End Handle New Upload ---
203
+
204
+
205
+ st.markdown("---") # Separator
206
+ st.markdown("**Adjust Parameters:**")
207
+ slider_max_blur = st.slider("Blur Intensity (Kernel Size)", min_value=3, max_value=101, step=2, value=31)
208
+ slider_depth_thr = st.slider("Subject Depth Threshold (Lower=Closer)", min_value=1, max_value=254, step=1, value=120)
209
+ slider_feather = st.slider("Feathering (Mask Smoothness)", min_value=1, max_value=51, step=2, value=5) # <-- Default changed to 5
210
+ # REMOVED: slider_sharpen = st.slider("Subject Sharpening Strength", min_value=0.0, max_value=2.5, step=0.1, value=1.0)
211
+ st.markdown("---") # Separator
212
+
213
+ # Button to trigger processing - disable if no file *loaded* in session state
214
+ process_button = st.button(
215
+ "Apply Blur",
216
+ type="primary",
217
+ disabled=(st.session_state.original_image_pil is None) # Disable if no original image is loaded
218
+ )
219
+
220
+
221
+ # --- Main Area for Images ---
222
+ col1, col2 = st.columns(2) # Create two columns for Original | Result
223
+
224
+ # --- Handle Processing Trigger ---
225
+ if process_button: # Button is only enabled if original_image_pil exists
226
+ if st.session_state.original_image_pil is not None:
227
+ # Reset error flag on new processing attempt
228
+ st.session_state.processing_error_occurred = False
229
+ # Clear previous results and duration before showing spinner
230
+ st.session_state.results = None
231
+ st.session_state.last_process_duration = None
232
+
233
+ with col2: # Show spinner in the results column
234
+ with st.spinner('Applying blur... This may take a moment...'):
235
+ results_output = process_image_blur(
236
+ pipeline_obj=pipe,
237
+ input_image_pil=st.session_state.original_image_pil, # Use the image from session state
238
+ max_blur_ksize=slider_max_blur,
239
+ depth_thresh=slider_depth_thr,
240
+ feather_ksize=slider_feather,
241
+ sharpen_val=1.0 # <-- Hardcoded sharpen value
242
+ )
243
+
244
+ # Check if processing returned successfully (4 values expected now)
245
+ if results_output is not None and len(results_output) == 4:
246
+ # Unpack results and store duration separately
247
+ blurred_pil, depth_pil, mask_pil, duration = results_output
248
+ st.session_state.results = (blurred_pil, depth_pil, mask_pil) # Store tuple
249
+ st.session_state.last_process_duration = duration
250
+ else:
251
+ # Processing failed (returned None or wrong number of items)
252
+ st.session_state.results = None # Ensure results are None
253
+ st.session_state.processing_error_occurred = True
254
+ st.session_state.last_process_duration = None
255
+
256
+ else:
257
+ # This case should technically not happen due to button disable logic, but good practice
258
+ st.error("No image loaded to process.")
259
+
260
+
261
+ # --- Display Images based on Session State ---
262
+
263
+ # Display Original Image in Column 1 if available
264
+ if st.session_state.original_image_pil is not None:
265
+ col1.image(st.session_state.original_image_pil, caption="Original Image", use_container_width=True)
266
+ else:
267
+ col1.markdown("### Upload an image")
268
+ col1.markdown("Use the sidebar controls to upload your portrait.")
269
+
270
+ # Display Results/Status in Column 2
271
+ if st.session_state.results is not None:
272
+ # Check if the first element (blurred_img) is not None, indicating successful processing within the function
273
+ blurred_img, depth_img, mask_img = st.session_state.results
274
+ if blurred_img is not None:
275
+ # Display success message with duration
276
+ if st.session_state.last_process_duration is not None:
277
+ st.success(f"Processing finished in {st.session_state.last_process_duration:.2f} seconds.")
278
+
279
+ col2.image(blurred_img, caption="Blurred Background Result", use_container_width=True)
280
+
281
+ # --- ADD DOWNLOAD BUTTON ---
282
+ # 1. Convert PIL Image to Bytes
283
+ buf = BytesIO()
284
+ blurred_img.save(buf, format="PNG") # Save image to buffer in PNG format
285
+ byte_im = buf.getvalue() # Get bytes from buffer
286
+
287
+ # 2. Add Download Button
288
+ col2.download_button(
289
+ label="Download Blurred Image",
290
+ data=byte_im,
291
+ file_name=f"blurred_{st.session_state.current_filename or 'result'}.png", # Suggest filename based on original
292
+ mime="image/png" # Set the MIME type for PNG
293
+ )
294
+ # --- END DOWNLOAD BUTTON ---
295
+
296
+ # Optionally display depth and mask below the main images or in expanders
297
+ with st.expander("Show Details (Depth Map & Mask)"):
298
+ # Use columns inside expander for better layout if needed
299
+ exp_col1, exp_col2 = st.columns(2)
300
+ exp_col1.image(depth_img, caption="Refined Depth Map", use_container_width=True)
301
+ exp_col2.image(mask_img, caption="Subject Mask", use_container_width=True)
302
+ else:
303
+ # This case might occur if results tuple was somehow malformed, treat as error
304
+ st.session_state.processing_error_occurred = True # Mark as error if blurred_img is None but results tuple exists
305
+ col2.error("An unexpected issue occurred during processing. Please check logs or try again.")
306
+
307
+
308
+ # Handle explicit error state OR "Ready to Process" state OR default state
309
+ if st.session_state.processing_error_occurred:
310
+ # Display specific error message if processing failed after button press
311
+ # The error might already be shown by st.error inside process_image_blur,
312
+ # but this provides a fallback message in col2.
313
+ col2.warning("Image processing failed. Check messages above or terminal logs.")
314
+
315
+ elif st.session_state.original_image_pil is not None and st.session_state.results is None:
316
+ # If file is uploaded/loaded but not processed yet (and no error occurred)
317
+ col2.markdown("### Ready to Process")
318
+ col2.markdown("Adjust parameters in the sidebar (if needed) and click **Apply Blur**.")
319
+
320
+ elif st.session_state.original_image_pil is None:
321
+ # Default state when no file is uploaded/loaded and nothing processed
322
+ col2.markdown("### Results")
323
+ col2.markdown("The processed image and details will appear here after uploading an image and clicking 'Apply Blur'.")