ApaCu committed on
Commit eac6c90 · verified · 1 Parent(s): 59d7f5d

Update app.py

Files changed (1):
  1. app.py  +427 -254

app.py CHANGED
@@ -6,311 +6,484 @@

  import gradio as gr
  import torch
- from PIL import Image
  import numpy as np
- from huggingface_hub import hf_hub_download
  import cv2
- import os # Import os for path handling
-
- # --- Dependency Imports (Need to be installed via pip or manual clone) ---
- # BasicSR related imports (for SwinIR, EDSR, CodeFormer utilities)
- # Wrap imports in try/except to handle missing libraries
- BASESR_AVAILABLE = False
- try:
-     from basicsr.archs.swinir_arch import SwinIR as SwinIR_Arch
-     from basicsr.archs.edsr_arch import EDSR as EDSR_Arch
-     from basicsr.utils import img2tensor, tensor2img
-     BASESR_AVAILABLE = True
- except ImportError:
-     print("Warning: basicsr not found. SwinIR, EDSR, and CodeFormer (using basicsr utils) will not be available.")
-
- # RealESRGAN import
- REALESRGAN_AVAILABLE = False
- try:
-     from realesrgan import RealESRGAN
-     REALESRGAN_AVAILABLE = True
- except ImportError:
-     print("Warning: realesrgan not found. Real-ESRGAN-x4 will not be available.")
-
- # CodeFormer import (Often requires manual setup or specific installation)
- # We assume it's importable if basicsr is available AND the CodeFormer library itself
- # was somehow installed (e.g., via cloning and manual setup).
- # Given the previous error, direct pip install from git often fails.
- # We'll primarily rely on basicsr utilities, but a proper CodeFormer instance
- # might still require its dedicated installation.
- CODEFORMER_AVAILABLE = False
- if BASESR_AVAILABLE: # CodeFormer often depends on basicsr utilities
-     try:
-         # Attempting a common import path if CodeFormer is installed separately
-         # This might need adjustment based on your CodeFormer install method
-         from CodeFormer import CodeFormer # Adjust import based on your CodeFormer install path
-         CODEFORMER_AVAILABLE = True
-     except ImportError:
-         print("Warning: CodeFormer library not directly importable. CodeFormer model might not work correctly.")
-         # If basicsr is available, we might still list the model but it might fail later if CodeFormer class isn't there
-         pass # Allow BASESR_AVAILABLE to potentially enable the config entry
-

- # --- Model Configuration ---
- # Dictionary of available models and their configuration
- # format: "UI Name": {"repo_id": "hf_repo_id", "filename": "weight_filename", "type": "upscale" or "face", ...}
- MODEL_CONFIGS = {}

- if REALESRGAN_AVAILABLE:
-     MODEL_CONFIGS["Real-ESRGAN-x4"] = {"repo_id": "RealESRGAN/RealESRGAN_x4plus", "filename": "RealESRGAN_x4plus.pth", "type": "upscale", "scale": 4}

- if BASESR_AVAILABLE:
-     MODEL_CONFIGS["SwinIR-4x"] = {"repo_id": "SwinIR/SwinIR-Large", "filename": "SwinIR_4x.pth", "type": "upscale", "scale": 4}
-     MODEL_CONFIGS["EDSR-x4"] = {"repo_id": "EDSR/edsr_x4", "filename": "edsr_x4.pth", "type": "upscale", "scale": 4}
-     # Add CodeFormer config only if basicsr is available, and potentially CODEFORMER_AVAILABLE is True
-     # Even if CODEFORMER_AVAILABLE is False, listing it might rely on basicsr providing necessary components (less likely)
-     # Given installation issues, let's only add it if the library is actually importable.
-     if CODEFORMER_AVAILABLE:
-         MODEL_CONFIGS["CodeFormer (Face Enhancement)"] = {"repo_id": "CodeFormer/codeformer", "filename": "codeformer.pth", "type": "face"}
-     else:
-         print("CodeFormer (Face Enhancement) model will not be listed due to import issues.")

- # No need to filter anymore, configs are only added if available

- # --- Model Loading Cache ---
- cached_model = {}
- cached_model_name = None
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- print(f"Using device: {device}") # This shows which device PyTorch detects
-
- # Function to load the selected model
- def load_model(model_name: str):
-     global cached_model, cached_model_name
-
-     if model_name == cached_model_name and cached_model is not None:
-         print(f"Using cached model: {model_name}")
-         return cached_model, MODEL_CONFIGS[model_name]['type']
-
-     print(f"Loading model: {model_name}")
-     config = MODEL_CONFIGS.get(model_name)
-     if config is None:
-         # This case should ideally not happen if UI choices are filtered,
-         # but good for safety.
-         return None, f"Error: Model '{model_name}' not configured or dependencies missing."

      try:
-         model_type = config['type']
-         model_path = hf_hub_download(repo_id=config['repo_id'], filename=config['filename'])
-
-         if model_name == "Real-ESRGAN-x4":
-             if not REALESRGAN_AVAILABLE: raise ImportError("realesrgan was not imported correctly.")
-             model = RealESRGAN(device, scale=config['scale'])
-             model.load_weights(model_path)
-
-         elif model_name == "SwinIR-4x":
-             if not BASESR_AVAILABLE: raise ImportError("basicsr was not imported correctly.")
-             model = SwinIR_Arch(
-                 upscale=config['scale'], in_chans=3, img_size=64, window_size=8,
-                 compress_ratio= -1, dilate_basis=-1, res_range=-1, attn_type='linear'
              )
-             pretrained_dict = torch.load(model_path, map_location=device)
-             model.load_state_dict(pretrained_dict, strict=True)
-             model.eval()
-             model.to(device)
-
-         elif model_name == "EDSR-x4":
-             if not BASESR_AVAILABLE: raise ImportError("basicsr was not imported correctly.")
-             model = EDSR_Arch(num_feat=64, num_block=16, upscale=config['scale'])
-             pretrained_dict = torch.load(model_path, map_location=device)
-             model.load_state_dict(pretrained_dict, strict=True)
-             model.eval()
-             model.to(device)

-         elif model_name == "CodeFormer (Face Enhancement)":
-             if not CODEFORMER_AVAILABLE:
-                 # This check is redundant if config is only added when available,
-                 # but good practice.
-                 raise ImportError("CodeFormer library was not imported correctly.")
-
-             # Ensure model_path is correct (downloaded via hf_hub_download)
-             # CodeFormer loading often needs specific handling for checkpoints (params_ema)
-             # This part is sensitive to the exact CodeFormer version/structure
-             # Assuming a similar loading pattern to basicsr models:
-             if CODEFORMER_AVAILABLE:
-                 model = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layer=9) # Basic instantiation
-                 pretrained_dict = torch.load(model_path, map_location=device)['params_ema']
-                 keys = list(pretrained_dict.keys())
-                 if keys and keys[0].startswith('module.'):
-                     pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items()}
-                 model.load_state_dict(pretrained_dict, strict=True)
-                 model.eval()
-                 model.to(device)
-             else:
-                 # Fallback check, should not be reached if config is filtered
-                 raise ImportError("CodeFormer library not available.")
-
-
          else:
-             # This should not be reached with filtered configs
-             raise ValueError(f"Configuration missing or invalid for model: {model_name}")
-
-         # Cache the loaded model
-         cached_model = model
-         cached_model_name = model_name
-
-         return model, model_type
-
-     except ImportError as ie:
-         # This catches errors if the library was *somehow* listed as available
-         # but then failed on a deeper import within load_model
-         print(f"Dependency check failed during load for {model_name}: {ie}")
-         # Clear cache on error
-         cached_model = None
-         cached_model_name = None
-         return None, f"Error: Dependency not fully available - {ie}. Model cannot be loaded."
      except Exception as e:
-         print(f"Error loading model {model_name}: {e}")
          import traceback
-         traceback.print_exc() # Print full traceback for debugging
-         # Clear cache on error
-         cached_model = None
-         cached_model_name = None
-         return None, f"Error loading model: {str(e)}"

- # Function to preprocess image (PIL RGB to OpenCV BGR numpy)
- def preprocess_image(image: Image.Image) -> np.ndarray:
-     img = np.array(image)
-     img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
-     return img

- # Function to postprocess image (OpenCV BGR numpy to PIL RGB)
- def postprocess_image(img: np.ndarray) -> Image.Image:
-     if img.dtype != np.uint8:
-         img = np.clip(img, 0, 255).astype(np.uint8)
-     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-     return Image.fromarray(img)

- # Main processing function
- def enhance_image(image: Image.Image, model_name: str):
-     if image is None:
-         # Return tuple of (message, image)
-         return "Please upload an image.", None

-     status_message = f"Processing image with {model_name}..."

-     # Load the selected model and its type
-     model, model_info = load_model(model_name)

-     if model is None:
-         # model_info contains the error message from load_model
-         return model_info, None

-     model_type = model_info # model_info is the type string ('upscale' or 'face')

      try:
-         # Preprocess the image (PIL RGB -> OpenCV BGR)
-         img_np_bgr = preprocess_image(image)

-         # Process based on model type and specific model implementation
          if model_type == "upscale":
-             print(f"Applying {model_name} upscaling...")
-             if model_name == "Real-ESRGAN-x4":
-                 if not REALESRGAN_AVAILABLE: raise ImportError("Real-ESRGAN library not available.")
-                 output_np_bgr = model.predict(img_np_bgr)
-             elif model_name in ["SwinIR-4x", "EDSR-x4"]:
-                 if not BASESR_AVAILABLE: raise ImportError(f"basicsr library not available for {model_name}")
-                 # HWC BGR uint8 -> CHW RGB float (0-1) -> send to device
-                 img_tensor = img2tensor(img_np_bgr.astype(np.float32) / 255., bgr2rgb=True, float32=True).unsqueeze(0).to(device)
-
-                 with torch.no_grad():
-                     output_tensor = model(img_tensor)
-
-                 # CHW RGB float (0-1) -> HWC RGB uint8 -> Convert to BGR
-                 output_np_rgb = tensor2img(output_tensor, rgb2bgr=False, min_max=(0, 1))
-                 output_np_bgr = cv2.cvtColor(output_np_rgb, cv2.COLOR_RGB2BGR)
-
              else:
-                 raise ValueError(f"Unknown upscale model type configuration: {model_name}")
-
-             status_message = f"Image upscaled successfully with {model_name}!"
-
          elif model_type == "face":
-             print(f"Applying {model_name} face enhancement...")
-             if model_name == "CodeFormer (Face Enhancement)":
-                 if not CODEFORMER_AVAILABLE: raise ImportError("CodeFormer library not available.")
-                 # Ensure model is on correct device
-                 if next(model.parameters()).device != device:
-                     model.to(device)
-
-                 # Call the enhance method (adjust parameters w, adain as needed)
-                 # CodeFormer enhance typically returns a tuple, first element is image
-                 output_np_bgr = model.enhance(img_np_bgr, w=0.5, adain=True)[0] # Example call
              else:
-                 raise ValueError(f"Unknown face enhancement model type configuration: {model_name}")
-
-             status_message = f"Face enhancement applied successfully with {model_name}!"
-
-         # Postprocess the output image (OpenCV BGR -> PIL RGB)
-         enhanced_image = postprocess_image(output_np_bgr)

-         return status_message, enhanced_image

-     except ImportError as ie:
-         # This catches errors if the library was imported initially but
-         # failed later when its functions/classes were called.
-         print(f"Error processing image due to missing dependency call: {ie}")
-         return f"Error processing image: Required library function not found - {ie}", None
      except Exception as e:
-         print(f"Error during processing: {e}")
          import traceback
-         traceback.print_exc() # Print full traceback for debugging
-         return f"Error processing image: {str(e)}", None

  # Gradio interface
  with gr.Blocks(title="Image Upscale & Enhancement - By FebryEnsz") as demo:
      gr.Markdown(
          """
-         # Image Upscale & Enhancement
-         **By FebryEnsz**

-         Upload an image and select a model to enhance it. Choose from multiple models for upscaling (to make it 'HD' or higher resolution) or face enhancement (to improve facial details and focus).

-         **Note:** This app requires specific Python libraries (`torch`, `basicsr`, `realesrgan`, `CodeFormer`) to be installed for all models to be available. If a model option is missing, its required library was not successfully installed or found during startup. Check the Space build logs for installation errors.
          """
      )

-     # Filter available choices based on loaded configs
-     available_models = list(MODEL_CONFIGS.keys())
-
      with gr.Row():
-         with gr.Column():
              image_input = gr.Image(label="Upload Image", type="pil")

-             if not available_models:
-                 model_choice = gr.Textbox(label="Select Model", value="No models available. Check build logs for dependency errors.", interactive=False)
-                 enhance_button = gr.Button("Enhance Image", interactive=False)
-                 print("No models are available because dependencies are missing.")
-             else:
                  model_choice = gr.Dropdown(
-                     choices=available_models,
-                     label="Select Model",
-                     value=available_models[0] # Default to the first available model
                  )
-                 # Removed scale_slider
-                 enhance_button = gr.Button("Enhance Image")

-         with gr.Column():
-             output_text = gr.Textbox(label="Status", max_lines=3)
              output_image = gr.Image(label="Enhanced Image")

-     # Connect the button to the processing function
-     if available_models: # Only connect if models are available
-         enhance_button.click(
-             fn=enhance_image,
-             inputs=[image_input, model_choice],
-             outputs=[output_text, output_image]
-         )

- # Launch the Gradio app
  if __name__ == "__main__":
-     # Set torch backend for potentially better performance on some systems
-     # Removed MPS fallback for simplicity unless specifically needed and tested
-     # if torch.backends.mps.is_available():
-     #     os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
-
      demo.launch()
 

  import gradio as gr
  import torch
  import numpy as np
+ from PIL import Image, ImageEnhance
  import cv2
+ import os
+ import sys
+ import subprocess
+ import time
+ from huggingface_hub import hf_hub_download

+ # Create cache directory for models
+ CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "image_enhancer")
+ os.makedirs(CACHE_DIR, exist_ok=True)

+ # Set up logging
+ import logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+ logger = logging.getLogger(__name__)

+ # Install required packages at runtime for Hugging Face Spaces
+ def install_dependencies():
+     logger.info("Checking and installing dependencies...")
+
+     packages_to_install = [
+         "opencv-python",
+         "opencv-contrib-python", # For dnn_superres module
+         "numpy",
+         "pillow",
+         "torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu",
+         "facexlib",
+         "basicsr",
+         "gfpgan",
+         "realesrgan"
+     ]
+
+     for package in packages_to_install:
+         try:
+             logger.info(f"Installing {package}")
+             # Split multi-token specs (e.g. the torch line) so pip sees separate arguments
+             subprocess.check_call([sys.executable, "-m", "pip", "install", *package.split()])
+         except Exception as e:
+             logger.warning(f"Error installing {package}: {str(e)}")
+
+     logger.info("Dependencies installation complete")

+ # Try to install dependencies on startup
+ try:
+     install_dependencies()
+     time.sleep(2) # Give some time for packages to settle
+ except Exception as e:
+     logger.error(f"Failed to install dependencies: {str(e)}")
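
# --- Editor's note: illustrative sketch, not part of this commit. ---
# The installer above pip-installs the full list on every startup. A lighter
# variant of the same idea (the REQUIRED mapping below is an assumption, not
# taken from this repo) probes each module first and installs only what is
# missing; subprocess and sys are already imported at the top of this file:
import importlib.util

REQUIRED = {"cv2": "opencv-contrib-python", "gfpgan": "gfpgan", "realesrgan": "realesrgan"}
for _module, _pip_name in REQUIRED.items():
    if importlib.util.find_spec(_module) is None:  # not importable yet
        subprocess.check_call([sys.executable, "-m", "pip", "install", _pip_name])
# --- End editor's note. ---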

+ # Check for GPU or CPU
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ logger.info(f"Using device: {device}")

+ # Dictionary of available models and their configuration
+ MODEL_OPTIONS = {
+     "OpenCV Super Resolution": {
+         "type": "upscale",
+         "method": "opencv",
+         "scale": 4
+     },
+     "Real-ESRGAN-x4": {
+         "repo_id": "xinntao/Real-ESRGAN",
+         "filename": "RealESRGAN_x4plus.pth",
+         "type": "upscale",
+         "method": "realesrgan",
+         "scale": 4
+     },
+     "GFPGAN (Face Enhancement)": {
+         "repo_id": "TencentARC/GFPGAN",
+         "filename": "GFPGANv1.4.pth",
+         "type": "face",
+         "method": "gfpgan",
+         "scale": 1
+     },
+     "HDR Enhancement": {
+         "type": "hdr",
+         "method": "custom",
+         "scale": 1
+     }
+ }
+
+ # Cache for loaded models
+ model_cache = {}
+
+ # Function to load the selected model with robust fallbacks
+ def load_model(model_name):
+     global model_cache
+
+     # Return cached model if available
+     if model_name in model_cache:
+         logger.info(f"Using cached model: {model_name}")
+         return model_cache[model_name]
+
+     logger.info(f"Loading model: {model_name}")
+     config = MODEL_OPTIONS.get(model_name)
+     if not config:
+         return None, f"Model {model_name} not found in configuration"
+
+     model_type = config["type"]
+
      try:
+         # OpenCV based models (always available as fallback)
+         if config["method"] == "opencv":
+             logger.info("Loading OpenCV Super Resolution model")
+             sr = cv2.dnn_superres.DnnSuperResImpl_create()
+
+             # Use EDSR as default model
+             model_path = hf_hub_download(
+                 repo_id="eugenesiow/edsr",
+                 filename="EDSR_x4.pb",
+                 cache_dir=CACHE_DIR
              )

+             sr.readModel(model_path)
+             sr.setModel("edsr", 4)
+
+             # Set backend to cuda if available
+             if torch.cuda.is_available():
+                 sr.setPreferableBackend(cv2.dnn.DNN_BACKEND_CUDA)
+                 sr.setPreferableTarget(cv2.dnn.DNN_TARGET_CUDA)
+
+             model_cache[model_name] = (sr, model_type)
+             return sr, model_type
+
+         # Real-ESRGAN models
+         elif config["method"] == "realesrgan":
+             try:
+                 from realesrgan import RealESRGAN
+                 logger.info("Loading Real-ESRGAN model")
+
+                 model_path = hf_hub_download(
+                     repo_id=config["repo_id"],
+                     filename=config["filename"],
+                     cache_dir=CACHE_DIR
+                 )
+
+                 model = RealESRGAN(device, scale=config["scale"])
+                 model.load_weights(model_path)
+
+                 model_cache[model_name] = (model, model_type)
+                 return model, model_type
+             except ImportError:
+                 logger.warning("RealESRGAN not available, falling back to OpenCV")
+                 return load_model("OpenCV Super Resolution")
+
+         # GFPGAN for face enhancement
+         elif config["method"] == "gfpgan":
+             try:
+                 from gfpgan import GFPGANer
+                 logger.info("Loading GFPGAN model")
+
+                 model_path = hf_hub_download(
+                     repo_id=config["repo_id"],
+                     filename=config["filename"],
+                     cache_dir=CACHE_DIR
+                 )
+
+                 face_enhancer = GFPGANer(
+                     model_path=model_path,
+                     upscale=config["scale"],
+                     arch='clean',
+                     channel_multiplier=2,
+                     bg_upsampler=None
+                 )
+
+                 model_cache[model_name] = (face_enhancer, model_type)
+                 return face_enhancer, model_type
+             except ImportError:
+                 logger.warning("GFPGAN not available, falling back to OpenCV")
+                 return load_model("OpenCV Super Resolution")
+
+         # HDR Enhancement (custom implementation)
+         elif config["method"] == "custom":
+             # No model to load for custom HDR
+             model_cache[model_name] = (None, model_type)
+             return None, model_type
+
          else:
+             raise ValueError(f"Unknown model method: {config['method']}")
+
      except Exception as e:
+         logger.error(f"Error loading model {model_name}: {str(e)}")
          import traceback
+         traceback.print_exc()
+
+         # Always provide a fallback method
+         if model_name != "OpenCV Super Resolution":
+             logger.info("Falling back to OpenCV Super Resolution")
+             return load_model("OpenCV Super Resolution")
+         else:
+             return None, f"Failed to load model: {str(e)}"
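
# --- Editor's note: illustrative sketch, not part of this commit. ---
# load_model returns a (model, model_type) pair and, for every entry except
# "OpenCV Super Resolution", silently retries with the OpenCV fallback on
# failure; the hypothetical helper below shows the resulting caller contract:
def _sketch_check_fallback():
    """Editor's sketch: a string starting with 'Failed' is the only hard error."""
    model, model_type = load_model("Real-ESRGAN-x4")
    if model_type.startswith("Failed"):
        logger.error(model_type)  # even the OpenCV fallback could not be loaded
    else:
        logger.info(f"Got a '{model_type}' model of type {type(model).__name__}")
# --- End editor's note. ---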

+ # Function to preprocess image for processing
+ def preprocess_image(image):
+     """Convert PIL image to numpy array for processing"""
+     if image is None:
+         return None
+
+     if isinstance(image, Image.Image):
+         # Convert PIL image to numpy array
+         img = np.array(image)
+     else:
+         img = image
+
+     # Handle grayscale images by converting to RGB
+     if len(img.shape) == 2:
+         img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)
+
+     # Handle RGBA images by removing alpha channel
+     if img.shape[2] == 4:
+         img = img[:, :, :3]
+
+     # Convert RGB to BGR for OpenCV processing
+     img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+     return img_bgr

+ # Function to postprocess image for display
+ def postprocess_image(img_bgr):
+     """Convert processed BGR image back to RGB PIL image"""
+     if img_bgr is None:
+         return None
+
+     # Ensure image is uint8
+     if img_bgr.dtype != np.uint8:
+         img_bgr = np.clip(img_bgr, 0, 255).astype(np.uint8)
+
+     # Convert BGR to RGB for PIL
+     img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+
+     return Image.fromarray(img_rgb)
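
# --- Editor's note: illustrative sketch, not part of this commit. ---
# preprocess_image / postprocess_image form a round trip between PIL RGB and
# OpenCV BGR; the hypothetical helper below shows the expected invariants:
def _sketch_roundtrip():
    """Editor's sketch: a 64x64 red PIL image survives the BGR round trip."""
    pil_in = Image.new("RGB", (64, 64), color=(255, 0, 0))
    bgr = preprocess_image(pil_in)           # numpy HWC array, BGR channel order
    assert bgr[0, 0].tolist() == [0, 0, 255]
    pil_out = postprocess_image(bgr)         # back to a PIL RGB image
    assert pil_out.getpixel((0, 0)) == (255, 0, 0)
# --- End editor's note. ---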

+ # HDR enhancement function
+ def enhance_hdr(img_bgr, strength=1.0):
+     """Custom HDR enhancement using OpenCV"""
+     # Convert BGR to RGB
+     img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+
+     # Convert to float32 for processing
+     img_float = img_rgb.astype(np.float32) / 255.0
+
+     # Convert to LAB color space for better contrast enhancement
+     # (for float32 input OpenCV returns L in [0, 100], a/b roughly in [-127, 127])
+     img_lab = cv2.cvtColor(img_float, cv2.COLOR_RGB2LAB)
+     l, a, b = cv2.split(img_lab)

+     # Apply CLAHE (Contrast Limited Adaptive Histogram Equalization) to the L
+     # channel, rescaling it from [0, 100] to an 8-bit range and back
+     clahe = cv2.createCLAHE(clipLimit=3.0 * strength, tileGridSize=(8, 8))
+     l_enhanced = clahe.apply(np.clip(l / 100.0 * 255, 0, 255).astype(np.uint8)) / 255.0 * 100.0

+     # Blend original and enhanced L channel (both in [0, 100])
+     l = l * (1 - strength) + l_enhanced * strength

+     # Merge channels
+     img_lab_enhanced = cv2.merge([l, a, b])
+     img_rgb_enhanced = cv2.cvtColor(img_lab_enhanced, cv2.COLOR_LAB2RGB)

+     # Add vibrance (increase saturation of low-saturation areas)
+     hsv = cv2.cvtColor(img_rgb_enhanced, cv2.COLOR_RGB2HSV)
+     h, s, v = cv2.split(hsv)
+
+     # Increase saturation adaptively (more for lower saturation, less for already saturated pixels)
+     saturation_factor = 0.3 * strength
+     s_enhanced = np.clip(s * (1 + saturation_factor * (1 - s)), 0, 1)
+
+     # Increase brightness slightly
+     v_enhanced = np.clip(v * (1 + 0.1 * strength), 0, 1)
+
+     # Merge HSV channels and convert back to RGB
+     hsv_enhanced = cv2.merge([h, s_enhanced, v_enhanced])
+     img_enhanced = cv2.cvtColor(hsv_enhanced, cv2.COLOR_HSV2RGB)
+
+     # Apply subtle detail enhancement
+     kernel_size = 5
+     blur = cv2.GaussianBlur(img_enhanced, (kernel_size, kernel_size), 0)
+     detail = img_enhanced - blur
+     img_enhanced = np.clip(img_enhanced + detail * (0.5 * strength), 0, 1)
+
+     # Convert back to BGR for output
+     img_enhanced = (img_enhanced * 255).astype(np.uint8)
+     img_bgr_enhanced = cv2.cvtColor(img_enhanced, cv2.COLOR_RGB2BGR)
+
+     return img_bgr_enhanced
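
# --- Editor's note: illustrative sketch, not part of this commit. ---
# enhance_hdr works on plain uint8 BGR arrays, so it can be exercised without
# the Gradio UI; the file name "sample.jpg" below is hypothetical:
def _sketch_hdr_on_file(path="sample.jpg"):
    """Editor's sketch: apply the HDR pass to an image on disk and save it."""
    bgr = cv2.imread(path)                   # uint8 BGR array, or None if missing
    if bgr is None:
        return None
    out = enhance_hdr(bgr, strength=0.8)     # same height/width, uint8 BGR
    cv2.imwrite("hdr_" + os.path.basename(path), out)
    return out
# --- End editor's note. ---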

+ # Main image enhancement function
+ def enhance_image(image, model_name, strength=1.0, denoise=0.0, sharpen=0.0):
+     """Enhance image using selected model with additional processing options"""
+     if image is None:
+         return "Please upload an image.", None
+
      try:
+         # Load model
+         model, model_type = load_model(model_name)
+         if isinstance(model_type, str) and model_type.startswith("Failed"):
+             return model_type, None

+         # Preprocess image
+         img_bgr = preprocess_image(image)
+         if img_bgr is None:
+             return "Failed to process image", None
+
+         # Apply denoising if requested
+         if denoise > 0:
+             strength_value = int(denoise * 10)
+             img_bgr = cv2.fastNlMeansDenoisingColored(
+                 img_bgr, None,
+                 h=strength_value,
+                 hColor=strength_value,
+                 templateWindowSize=7,
+                 searchWindowSize=21
+             )
+
+         # Process based on model type
          if model_type == "upscale":
+             logger.info(f"Upscaling image with {model_name}")
+
+             if model_name == "OpenCV Super Resolution":
+                 # OpenCV super resolution
+                 output_bgr = model.upsample(img_bgr)
+
+             elif model_name == "Real-ESRGAN-x4":
+                 # Real-ESRGAN upscaling
+                 try:
+                     output_bgr = model.predict(img_bgr)
+                 except Exception as e:
+                     logger.error(f"Error with Real-ESRGAN: {str(e)}")
+                     # Fall back to OpenCV
+                     fallback_model, _ = load_model("OpenCV Super Resolution")
+                     output_bgr = fallback_model.upsample(img_bgr)
+
              else:
+                 # Default to the OpenCV upscaler (loaded via load_model so its
+                 # weights are read, and the result is assigned to output_bgr)
+                 fallback_model, _ = load_model("OpenCV Super Resolution")
+                 output_bgr = fallback_model.upsample(img_bgr)
+
          elif model_type == "face":
+             logger.info(f"Enhancing face with {model_name}")
+
+             if model_name == "GFPGAN (Face Enhancement)":
+                 try:
+                     # GFPGAN returns (cropped_faces, restored_faces, restored_img)
+                     _, _, output_bgr = model.enhance(
+                         img_bgr,
+                         has_aligned=False,
+                         only_center_face=False,
+                         paste_back=True
+                     )
+                 except Exception as e:
+                     logger.error(f"Error with GFPGAN: {str(e)}")
+                     # Fall back to basic upscaling
+                     fallback_model, _ = load_model("OpenCV Super Resolution")
+                     output_bgr = fallback_model.upsample(img_bgr)
              else:
+                 # Default upscaling if no specific face model matches (use the
+                 # loaded OpenCV upscaler rather than an uninitialised instance)
+                 fallback_model, _ = load_model("OpenCV Super Resolution")
+                 output_bgr = fallback_model.upsample(img_bgr)
+
+         elif model_type == "hdr":
+             logger.info("Applying HDR enhancement")
+             # Custom HDR enhancement
+             output_bgr = enhance_hdr(img_bgr, strength=strength)

+         else:
+             return f"Unknown model type: {model_type}", None
+
+         # Apply sharpening if requested
+         if sharpen > 0:
+             sharpen_kernel = np.array([
+                 [-1, -1, -1],
+                 [-1, 9 + sharpen * 2, -1],
+                 [-1, -1, -1]
+             ])
+             output_bgr = cv2.filter2D(output_bgr, -1, sharpen_kernel)
+
+         # Post-process and return image
+         enhanced_image = postprocess_image(output_bgr)
+
+         return "Image enhanced successfully!", enhanced_image

      except Exception as e:
+         logger.error(f"Error processing image: {str(e)}")
          import traceback
+         traceback.print_exc()
+         return f"Error: {str(e)}", None
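
# --- Editor's note: illustrative sketch, not part of this commit. ---
# enhance_image takes a PIL image plus the UI's slider values and returns a
# (status_message, PIL image or None) pair, so it can be smoke-tested without
# launching Gradio (the grey test image below is synthetic):
def _sketch_enhance_smoke_test():
    """Editor's sketch: run the HDR path end to end on a synthetic image."""
    test_img = Image.new("RGB", (128, 128), color=(90, 120, 150))
    status, result = enhance_image(test_img, "HDR Enhancement",
                                   strength=0.8, denoise=0.0, sharpen=0.2)
    logger.info(status)   # "Image enhanced successfully!" on the happy path
    return result         # a PIL.Image on success, None on error
# --- End editor's note. ---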

  # Gradio interface
  with gr.Blocks(title="Image Upscale & Enhancement - By FebryEnsz") as demo:
      gr.Markdown(
          """
+         # 🖼️ Image Upscale & Enhancement
+         ### By FebryEnsz

+         Upload an image and enhance it with AI-powered upscaling and enhancement.

+         **Features:**
+         - Super-resolution upscaling (4x)
+         - Face enhancement for portraits
+         - HDR enhancement for better contrast and details
          """
      )

      with gr.Row():
+         with gr.Column(scale=1):
              image_input = gr.Image(label="Upload Image", type="pil")

+             with gr.Box():
+                 gr.Markdown("### Enhancement Options")
                  model_choice = gr.Dropdown(
+                     choices=list(MODEL_OPTIONS.keys()),
+                     label="Model Selection",
+                     value="OpenCV Super Resolution"
                  )
+
+                 with gr.Accordion("Advanced Settings", open=False):
+                     strength_slider = gr.Slider(
+                         minimum=0.1,
+                         maximum=1.0,
+                         step=0.1,
+                         label="Enhancement Strength",
+                         value=0.8,
+                     )
+
+                     denoise_slider = gr.Slider(
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         label="Noise Reduction",
+                         value=0.0,
+                     )
+
+                     sharpen_slider = gr.Slider(
+                         minimum=0.0,
+                         maximum=1.0,
+                         step=0.1,
+                         label="Sharpening",
+                         value=0.0,
+                     )

+             enhance_button = gr.Button("✨ Enhance Image", variant="primary")
+
+         with gr.Column(scale=1):
+             output_text = gr.Textbox(label="Status")
              output_image = gr.Image(label="Enhanced Image")

+     # Handle model change to update UI
+     def on_model_change(model_name):
+         model_config = MODEL_OPTIONS.get(model_name, {})
+         model_type = model_config.get("type", "")
+
+         # Update UI based on model type
+         if model_type == "hdr":
+             return gr.update(visible=True, label="HDR Intensity")
+         elif model_type == "face":
+             return gr.update(visible=True, label="Enhancement Strength")
+         else:
+             return gr.update(visible=True, label="Enhancement Strength")
+
+     model_choice.change(on_model_change, inputs=[model_choice], outputs=[strength_slider])
+
+     # Connect button to function
+     enhance_button.click(
+         fn=enhance_image,
+         inputs=[image_input, model_choice, strength_slider, denoise_slider, sharpen_slider],
+         outputs=[output_text, output_image]
+     )
+
+     # Footer information
+     gr.Markdown(
+         """
+         ### Tips
+         - For best results with face enhancement, ensure faces are clearly visible
+         - HDR enhancement works best with images that have both bright and dark areas
+         - For noisy images, try increasing the noise reduction slider
+
+         ---
+         Version 2.0 | Running on: """ + ("GPU 🚀" if torch.cuda.is_available() else "CPU ⚙️")
+     )

+ # Launch the app
  if __name__ == "__main__":
      demo.launch()
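
# --- Editor's note: illustrative sketch, not part of this commit. ---
# On a CPU-only Space the upscaling calls can take a while; if request timeouts
# become a problem, one option (standard Gradio API, the values here are
# assumptions) is to enable the request queue and bind the server explicitly:
#
#     if __name__ == "__main__":
#         demo.queue(max_size=8)
#         demo.launch(server_name="0.0.0.0", server_port=7860)
# --- End editor's note. ---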