ApaCu commited on
Commit
f60e836
·
verified ·
1 Parent(s): 6682575

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +310 -0
app.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ # Image Upscale and Enhancement with Multiple Models
3
+ # By FebryEnsz
4
+ # SDK: Gradio
5
+ # Hosted on Hugging Face Spaces
6
+
7
+ import gradio as gr
8
+ import torch
9
+ from PIL import Image
10
+ import numpy as np
11
+ from huggingface_hub import hf_hub_download
12
+ import cv2
13
+ import os # Import os for path handling
14
+
15
+ # --- Dependency Imports (Need to be installed via pip or manual clone) ---
16
+ # BasicSR related imports (for SwinIR, EDSR, CodeFormer utilities)
17
+ try:
18
+ from basicsr.archs.swinir_arch import SwinIR as SwinIR_Arch
19
+ from basicsr.archs.edsr_arch import EDSR as EDSR_Arch
20
+ from basicsr.utils import img2tensor, tensor2img
21
+ BASESR_AVAILABLE = True
22
+ except ImportError:
23
+ print("Warning: basicsr not found. SwinIR, EDSR, and CodeFormer (using basicsr utils) will not be available.")
24
+ BASESR_AVAILABLE = False
25
+
26
+ # RealESRGAN import
27
+ try:
28
+ from realesrgan import RealESRGAN
29
+ REALESRGAN_AVAILABLE = True
30
+ except ImportError:
31
+ print("Warning: realesrgan not found. Real-ESRGAN-x4 will not be available.")
32
+ REALESRGAN_AVAILABLE = False
33
+
34
+ # CodeFormer import (This assumes CodeFormer is installed and importable,
35
+ # or integrated into basicsr's structure) - often requires manual setup.
36
+ # We will use basicsr's utilities for CodeFormer if available, and try a direct import if possible.
37
+ try:
38
+ # Attempting a common import path if CodeFormer is installed separately
39
+ from CodeFormer import CodeFormer # Adjust import based on your CodeFormer install
40
+ CODEFORMER_AVAILABLE = True
41
+ except ImportError:
42
+ print("Warning: CodeFormer not found. CodeFormer (Face Enhancement) will not be available.")
43
+ CODEFORMER_AVAILABLE = False
44
+
45
+
46
+ # --- Model Configuration ---
47
+ # Dictionary of available models and their configuration
48
+ # format: "UI Name": {"repo_id": "hf_repo_id", "filename": "weight_filename", "type": "upscale" or "face"}
49
+ MODEL_CONFIGS = {
50
+ "Real-ESRGAN-x4": {"repo_id": "RealESRGAN/RealESRGAN_x4plus", "filename": "RealESRGAN_x4plus.pth", "type": "upscale", "scale": 4} if REALESRGAN_AVAILABLE else None,
51
+ "SwinIR-4x": {"repo_id": "SwinIR/SwinIR-Large", "filename": "SwinIR_4x.pth", "type": "upscale", "scale": 4} if BASESR_AVAILABLE else None,
52
+ "EDSR-x4": {"repo_id": "EDSR/edsr_x4", "filename": "edsr_x4.pth", "type": "upscale", "scale": 4} if BASESR_AVAILABLE else None,
53
+ # Note: CodeFormer often requires its own setup. Assuming basicsr utils might help,
54
+ # but its core logic is in the CodeFormer library.
55
+ "CodeFormer (Face Enhancement)": {"repo_id": "CodeFormer/codeformer", "filename": "codeformer.pth", "type": "face"} if CODEFORMER_AVAILABLE or BASESR_AVAILABLE else None, # Use CodeFormer if installed, otherwise rely on basicsr utilities being present
56
+ }
57
+
58
+ # Filter out unavailable models
59
+ MODEL_CONFIGS = {k: v for k, v in MODEL_CONFIGS.items() if v is not None}
60
+
61
+ # --- Model Loading Cache ---
62
+ # Use a simple cache to avoid reloading the same model multiple times
63
+ cached_model = {}
64
+ cached_model_name = None
65
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
+ print(f"Using device: {device}")
67
+
68
+ # Function to load the selected model
69
+ def load_model(model_name):
70
+ global cached_model, cached_model_name
71
+
72
+ if model_name == cached_model_name and cached_model is not None:
73
+ print(f"Using cached model: {model_name}")
74
+ return cached_model, MODEL_CONFIGS[model_name]['type']
75
+
76
+ print(f"Loading model: {model_name}")
77
+ config = MODEL_CONFIGS.get(model_name)
78
+ if config is None:
79
+ return None, f"Error: Model '{model_name}' not supported or dependencies missing."
80
+
81
+ try:
82
+ model_type = config['type']
83
+ model_path = hf_hub_download(repo_id=config['repo_id'], filename=config['filename'])
84
+
85
+ if model_name == "Real-ESRGAN-x4":
86
+ if not REALESRGAN_AVAILABLE: raise ImportError("realesrgan not installed.")
87
+ model = RealESRGAN(device, scale=config['scale'])
88
+ model.load_weights(model_path)
89
+
90
+ elif model_name == "SwinIR-4x":
91
+ if not BASESR_AVAILABLE: raise ImportError("basicsr not installed.")
92
+ # SwinIR requires specific initialization parameters
93
+ # These match the SwinIR_4x.pth model from the repo
94
+ model = SwinIR_Arch(
95
+ upscale=config['scale'], in_chans=3, img_size=64, window_size=8,
96
+ compress_ratio= -1, dilate_basis=-1, res_range=-1, attn_type='linear'
97
+ )
98
+ # Load weights, handling potential key mismatches if necessary
99
+ pretrained_dict = torch.load(model_path, map_location=device)
100
+ model.load_state_dict(pretrained_dict, strict=True) # strict=False if keys might mismatch
101
+ model.eval() # Set to evaluation mode
102
+ model.to(device)
103
+
104
+ elif model_name == "EDSR-x4":
105
+ if not BASESR_AVAILABLE: raise ImportError("basicsr not installed.")
106
+ # EDSR architecture needs scale, num_feat, num_block
107
+ # Assuming typical values for EDSR_x4 from the repo
108
+ model = EDSR_Arch(num_feat=64, num_block=16, upscale=config['scale'])
109
+ pretrained_dict = torch.load(model_path, map_location=device)
110
+ model.load_state_dict(pretrained_dict, strict=True)
111
+ model.eval()
112
+ model.to(device)
113
+
114
+ elif model_name == "CodeFormer (Face Enhancement)":
115
+ if not (CODEFORMER_AVAILABLE or BASESR_AVAILABLE): raise ImportError("CodeFormer or basicsr not installed.")
116
+ # CodeFormer loading is more complex, often requiring instantiation with specific args
117
+ # and potentially related models (like GFPGAN for background).
118
+ # For simplicity here, we assume a basic CodeFormer instance can be created.
119
+ # This part might need adjustment based on your CodeFormer installation.
120
+ if CODEFORMER_AVAILABLE:
121
+ # This is a simplified instantiation; a real CodeFormer usage might need more args
122
+ model = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layer=9)
123
+ pretrained_dict = torch.load(model_path, map_location=device)['params_ema'] # CodeFormer often saves params_ema
124
+ # Need to handle potential DataParallel prefix if saved from DP
125
+ keys = list(pretrained_dict.keys())
126
+ if keys and keys[0].startswith('module.'):
127
+ pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items()}
128
+ model.load_state_dict(pretrained_dict, strict=True)
129
+ model.eval()
130
+ model.to(device)
131
+ elif BASESR_AVAILABLE:
132
+ # Fallback: If CodeFormer library isn't directly importable but basicsr is,
133
+ # we *cannot* instantiate the CodeFormer model itself unless basicsr provides it.
134
+ # This option is likely only possible if CodeFormer is installed *within* a basicsr environment
135
+ # or if basicsr provides the architecture. Given the complexity, let's just raise an error
136
+ # if CODEFORMER_AVAILABLE is False.
137
+ raise ImportError("CodeFormer library not found. BasicSR utilities alone are not enough to instantiate CodeFormer.")
138
+
139
+
140
+ else:
141
+ raise ValueError(f"Configuration missing for model: {model_name}")
142
+
143
+ # Cache the loaded model
144
+ cached_model = model
145
+ cached_model_name = model_name
146
+
147
+ return model, model_type
148
+
149
+ except ImportError as ie:
150
+ print(f"Dependency missing for {model_name}: {ie}")
151
+ return None, f"Error: Missing dependency - {ie}. Please ensure model libraries are installed."
152
+ except Exception as e:
153
+ print(f"Error loading model {model_name}: {e}")
154
+ # Clear cache on error
155
+ cached_model = None
156
+ cached_model_name = None
157
+ return None, f"Error loading model: {str(e)}"
158
+
159
+ # Function to preprocess image (PIL RGB to OpenCV BGR numpy)
160
+ def preprocess_image(image: Image.Image) -> np.ndarray:
161
+ img = np.array(image)
162
+ # OpenCV uses BGR, PIL uses RGB. Need conversion.
163
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
164
+ return img
165
+
166
+ # Function to postprocess image (OpenCV BGR numpy to PIL RGB)
167
+ def postprocess_image(img: np.ndarray) -> Image.Image:
168
+ # Ensure image is in the correct range and type before converting
169
+ if img.dtype != np.uint8:
170
+ img = np.clip(img, 0, 255).astype(np.uint8)
171
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
172
+ return Image.fromarray(img)
173
+
174
+ # Main processing function
175
+ def enhance_image(image: Image.Image, model_name: str):
176
+ if image is None:
177
+ return "Please upload an image.", None
178
+
179
+ status_message = f"Processing image with {model_name}..."
180
+
181
+ # Load the selected model and its type
182
+ model, model_info = load_model(model_name)
183
+
184
+ if model is None:
185
+ # model_info contains the error message if loading failed
186
+ return model_info, None
187
+
188
+ model_type = model_info # model_info is the type string ('upscale' or 'face')
189
+
190
+ try:
191
+ # Preprocess the image (PIL RGB -> OpenCV BGR)
192
+ img_np_bgr = preprocess_image(image)
193
+
194
+ # Process based on model type and specific model implementation
195
+ if model_type == "upscale":
196
+ print(f"Applying {model_name} upscaling...")
197
+ if model_name == "Real-ESRGAN-x4":
198
+ # RealESRGAN works with uint8 BGR numpy directly
199
+ output_np_bgr = model.predict(img_np_bgr)
200
+ elif model_name in ["SwinIR-4x", "EDSR-x4"]:
201
+ if not BASESR_AVAILABLE:
202
+ raise ImportError(f"basicsr is required for {model_name}")
203
+ # These models often work with float tensors (0-1 range)
204
+ # Using basicsr utils: HWC BGR uint8 -> CHW RGB float (0-1) -> send to device
205
+ img_tensor = img2tensor(img_np_bgr.astype(np.float32) / 255., bgr2rgb=True, float32=True).unsqueeze(0).to(device)
206
+
207
+ with torch.no_grad():
208
+ output_tensor = model(img_tensor)
209
+
210
+ # Using basicsr utils: CHW RGB float (0-1) -> HWC RGB uint8 -> Convert to BGR for postprocessing
211
+ output_np_rgb = tensor2img(output_tensor, rgb2bgr=False, min_max=(0, 1))
212
+ output_np_bgr = cv2.cvtColor(output_np_rgb, cv2.COLOR_RGB2BGR)
213
+
214
+ else:
215
+ raise ValueError(f"Unknown upscale model: {model_name}")
216
+
217
+ status_message = f"Image upscaled successfully with {model_name}!"
218
+
219
+ elif model_type == "face":
220
+ print(f"Applying {model_name} face enhancement...")
221
+ if model_name == "CodeFormer (Face Enhancement)":
222
+ if not (CODEFORMER_AVAILABLE or BASESR_AVAILABLE):
223
+ raise ImportError(f"CodeFormer or basicsr is required for {model_name}")
224
+ # CodeFormer's enhance method typically expects uint8 BGR numpy
225
+ # It might return multiple outputs, the first is usually the enhanced image
226
+ # Example: CodeFormer's inference script might return (restored_img, bboxes)
227
+ # We assume the image is the first element.
228
+ # Note: CodeFormer often needs additional setup/parameters for GFPGAN, etc.
229
+ # This is a simplified call.
230
+ # Ensure model is on correct device before call
231
+ if next(model.parameters()).device != device:
232
+ model.to(device)
233
+
234
+ # A minimal CodeFormer enhancement might look like this, but the actual API
235
+ # depends on the CodeFormer library version/structure you're using.
236
+ # The original CodeFormer repo's inference takes numpy BGR.
237
+ # This is a *placeholder* call assuming such a method exists and works like this:
238
+ output_np_bgr = model.enhance(img_np_bgr, w=0.5, adain=True)[0] # w and adain are common params
239
+
240
+
241
+ else:
242
+ raise ValueError(f"Unknown face enhancement model: {model_name}")
243
+
244
+ status_message = f"Face enhancement applied successfully with {model_name}!"
245
+
246
+ # Postprocess the output image (OpenCV BGR -> PIL RGB)
247
+ enhanced_image = postprocess_image(output_np_bgr)
248
+
249
+ return status_message, enhanced_image
250
+
251
+ except ImportError as ie:
252
+ return f"Error processing image: Missing dependency - {ie}", None
253
+ except Exception as e:
254
+ print(f"Error during processing: {e}")
255
+ import traceback
256
+ traceback.print_exc() # Print full traceback for debugging
257
+ return f"Error processing image: {str(e)}", None
258
+
259
+ # Gradio interface
260
+ with gr.Blocks(title="Image Upscale & Enhancement - By FebryEnsz") as demo:
261
+ gr.Markdown(
262
+ """
263
+ # Image Upscale & Enhancement
264
+ **By FebryEnsz**
265
+
266
+ Upload an image and select a model to enhance it. Choose from multiple models for upscaling (to make it 'HD' or higher resolution) or face enhancement (to improve facial details and focus).
267
+
268
+ **Note:** This app requires specific Python libraries (`torch`, `basicsr`, `realesrgan`, `CodeFormer`) to be installed for all models to be available. If a model option is missing, its required library is not installed or found.
269
+ """
270
+ )
271
+
272
+ with gr.Row():
273
+ with gr.Column():
274
+ image_input = gr.Image(label="Upload Image", type="pil")
275
+
276
+ # Filter available choices based on loaded configs
277
+ available_models = list(MODEL_CONFIGS.keys())
278
+
279
+ if not available_models:
280
+ model_choice = gr.Textbox(label="Select Model", value="No models available. Check dependencies.", interactive=False)
281
+ enhance_button = gr.Button("Enhance Image", interactive=False)
282
+ print("No models are available because dependencies are missing.")
283
+ else:
284
+ model_choice = gr.Dropdown(
285
+ choices=available_models,
286
+ label="Select Model",
287
+ value=available_models[0] # Default to the first available model
288
+ )
289
+ # Removed scale_slider as models are fixed scale (x4)
290
+ enhance_button = gr.Button("Enhance Image")
291
+
292
+ with gr.Column():
293
+ output_text = gr.Textbox(label="Status", max_lines=2)
294
+ output_image = gr.Image(label="Enhanced Image")
295
+
296
+ # Connect the button to the processing function
297
+ if available_models: # Only connect if models are available
298
+ enhance_button.click(
299
+ fn=enhance_image,
300
+ inputs=[image_input, model_choice],
301
+ outputs=[output_text, output_image]
302
+ )
303
+
304
+ # Launch the Gradio app
305
+ if __name__ == "__main__":
306
+ # Set torch backend for potentially better performance on some systems
307
+ if torch.backends.mps.is_available():
308
+ os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # Optional: enable fallback for MPS
309
+
310
+ demo.launch()