Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,310 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# app.py
|
2 |
+
# Image Upscale and Enhancement with Multiple Models
|
3 |
+
# By FebryEnsz
|
4 |
+
# SDK: Gradio
|
5 |
+
# Hosted on Hugging Face Spaces
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import torch
|
9 |
+
from PIL import Image
|
10 |
+
import numpy as np
|
11 |
+
from huggingface_hub import hf_hub_download
|
12 |
+
import cv2
|
13 |
+
import os # Import os for path handling
|
14 |
+
|
15 |
+
# --- Dependency Imports (Need to be installed via pip or manual clone) ---
|
16 |
+
# BasicSR related imports (for SwinIR, EDSR, CodeFormer utilities)
|
17 |
+
try:
|
18 |
+
from basicsr.archs.swinir_arch import SwinIR as SwinIR_Arch
|
19 |
+
from basicsr.archs.edsr_arch import EDSR as EDSR_Arch
|
20 |
+
from basicsr.utils import img2tensor, tensor2img
|
21 |
+
BASESR_AVAILABLE = True
|
22 |
+
except ImportError:
|
23 |
+
print("Warning: basicsr not found. SwinIR, EDSR, and CodeFormer (using basicsr utils) will not be available.")
|
24 |
+
BASESR_AVAILABLE = False
|
25 |
+
|
26 |
+
# RealESRGAN import
|
27 |
+
try:
|
28 |
+
from realesrgan import RealESRGAN
|
29 |
+
REALESRGAN_AVAILABLE = True
|
30 |
+
except ImportError:
|
31 |
+
print("Warning: realesrgan not found. Real-ESRGAN-x4 will not be available.")
|
32 |
+
REALESRGAN_AVAILABLE = False
|
33 |
+
|
34 |
+
# CodeFormer import (This assumes CodeFormer is installed and importable,
|
35 |
+
# or integrated into basicsr's structure) - often requires manual setup.
|
36 |
+
# We will use basicsr's utilities for CodeFormer if available, and try a direct import if possible.
|
37 |
+
try:
|
38 |
+
# Attempting a common import path if CodeFormer is installed separately
|
39 |
+
from CodeFormer import CodeFormer # Adjust import based on your CodeFormer install
|
40 |
+
CODEFORMER_AVAILABLE = True
|
41 |
+
except ImportError:
|
42 |
+
print("Warning: CodeFormer not found. CodeFormer (Face Enhancement) will not be available.")
|
43 |
+
CODEFORMER_AVAILABLE = False
|
44 |
+
|
45 |
+
|
46 |
+
# --- Model Configuration ---
|
47 |
+
# Dictionary of available models and their configuration
|
48 |
+
# format: "UI Name": {"repo_id": "hf_repo_id", "filename": "weight_filename", "type": "upscale" or "face"}
|
49 |
+
MODEL_CONFIGS = {
|
50 |
+
"Real-ESRGAN-x4": {"repo_id": "RealESRGAN/RealESRGAN_x4plus", "filename": "RealESRGAN_x4plus.pth", "type": "upscale", "scale": 4} if REALESRGAN_AVAILABLE else None,
|
51 |
+
"SwinIR-4x": {"repo_id": "SwinIR/SwinIR-Large", "filename": "SwinIR_4x.pth", "type": "upscale", "scale": 4} if BASESR_AVAILABLE else None,
|
52 |
+
"EDSR-x4": {"repo_id": "EDSR/edsr_x4", "filename": "edsr_x4.pth", "type": "upscale", "scale": 4} if BASESR_AVAILABLE else None,
|
53 |
+
# Note: CodeFormer often requires its own setup. Assuming basicsr utils might help,
|
54 |
+
# but its core logic is in the CodeFormer library.
|
55 |
+
"CodeFormer (Face Enhancement)": {"repo_id": "CodeFormer/codeformer", "filename": "codeformer.pth", "type": "face"} if CODEFORMER_AVAILABLE or BASESR_AVAILABLE else None, # Use CodeFormer if installed, otherwise rely on basicsr utilities being present
|
56 |
+
}
|
57 |
+
|
58 |
+
# Filter out unavailable models
|
59 |
+
MODEL_CONFIGS = {k: v for k, v in MODEL_CONFIGS.items() if v is not None}
|
60 |
+
|
61 |
+
# --- Model Loading Cache ---
|
62 |
+
# Use a simple cache to avoid reloading the same model multiple times
|
63 |
+
cached_model = {}
|
64 |
+
cached_model_name = None
|
65 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
66 |
+
print(f"Using device: {device}")
|
67 |
+
|
68 |
+
# Function to load the selected model
|
69 |
+
def load_model(model_name):
|
70 |
+
global cached_model, cached_model_name
|
71 |
+
|
72 |
+
if model_name == cached_model_name and cached_model is not None:
|
73 |
+
print(f"Using cached model: {model_name}")
|
74 |
+
return cached_model, MODEL_CONFIGS[model_name]['type']
|
75 |
+
|
76 |
+
print(f"Loading model: {model_name}")
|
77 |
+
config = MODEL_CONFIGS.get(model_name)
|
78 |
+
if config is None:
|
79 |
+
return None, f"Error: Model '{model_name}' not supported or dependencies missing."
|
80 |
+
|
81 |
+
try:
|
82 |
+
model_type = config['type']
|
83 |
+
model_path = hf_hub_download(repo_id=config['repo_id'], filename=config['filename'])
|
84 |
+
|
85 |
+
if model_name == "Real-ESRGAN-x4":
|
86 |
+
if not REALESRGAN_AVAILABLE: raise ImportError("realesrgan not installed.")
|
87 |
+
model = RealESRGAN(device, scale=config['scale'])
|
88 |
+
model.load_weights(model_path)
|
89 |
+
|
90 |
+
elif model_name == "SwinIR-4x":
|
91 |
+
if not BASESR_AVAILABLE: raise ImportError("basicsr not installed.")
|
92 |
+
# SwinIR requires specific initialization parameters
|
93 |
+
# These match the SwinIR_4x.pth model from the repo
|
94 |
+
model = SwinIR_Arch(
|
95 |
+
upscale=config['scale'], in_chans=3, img_size=64, window_size=8,
|
96 |
+
compress_ratio= -1, dilate_basis=-1, res_range=-1, attn_type='linear'
|
97 |
+
)
|
98 |
+
# Load weights, handling potential key mismatches if necessary
|
99 |
+
pretrained_dict = torch.load(model_path, map_location=device)
|
100 |
+
model.load_state_dict(pretrained_dict, strict=True) # strict=False if keys might mismatch
|
101 |
+
model.eval() # Set to evaluation mode
|
102 |
+
model.to(device)
|
103 |
+
|
104 |
+
elif model_name == "EDSR-x4":
|
105 |
+
if not BASESR_AVAILABLE: raise ImportError("basicsr not installed.")
|
106 |
+
# EDSR architecture needs scale, num_feat, num_block
|
107 |
+
# Assuming typical values for EDSR_x4 from the repo
|
108 |
+
model = EDSR_Arch(num_feat=64, num_block=16, upscale=config['scale'])
|
109 |
+
pretrained_dict = torch.load(model_path, map_location=device)
|
110 |
+
model.load_state_dict(pretrained_dict, strict=True)
|
111 |
+
model.eval()
|
112 |
+
model.to(device)
|
113 |
+
|
114 |
+
elif model_name == "CodeFormer (Face Enhancement)":
|
115 |
+
if not (CODEFORMER_AVAILABLE or BASESR_AVAILABLE): raise ImportError("CodeFormer or basicsr not installed.")
|
116 |
+
# CodeFormer loading is more complex, often requiring instantiation with specific args
|
117 |
+
# and potentially related models (like GFPGAN for background).
|
118 |
+
# For simplicity here, we assume a basic CodeFormer instance can be created.
|
119 |
+
# This part might need adjustment based on your CodeFormer installation.
|
120 |
+
if CODEFORMER_AVAILABLE:
|
121 |
+
# This is a simplified instantiation; a real CodeFormer usage might need more args
|
122 |
+
model = CodeFormer(dim_embd=512, codebook_size=1024, n_head=8, n_layer=9)
|
123 |
+
pretrained_dict = torch.load(model_path, map_location=device)['params_ema'] # CodeFormer often saves params_ema
|
124 |
+
# Need to handle potential DataParallel prefix if saved from DP
|
125 |
+
keys = list(pretrained_dict.keys())
|
126 |
+
if keys and keys[0].startswith('module.'):
|
127 |
+
pretrained_dict = {k[7:]: v for k, v in pretrained_dict.items()}
|
128 |
+
model.load_state_dict(pretrained_dict, strict=True)
|
129 |
+
model.eval()
|
130 |
+
model.to(device)
|
131 |
+
elif BASESR_AVAILABLE:
|
132 |
+
# Fallback: If CodeFormer library isn't directly importable but basicsr is,
|
133 |
+
# we *cannot* instantiate the CodeFormer model itself unless basicsr provides it.
|
134 |
+
# This option is likely only possible if CodeFormer is installed *within* a basicsr environment
|
135 |
+
# or if basicsr provides the architecture. Given the complexity, let's just raise an error
|
136 |
+
# if CODEFORMER_AVAILABLE is False.
|
137 |
+
raise ImportError("CodeFormer library not found. BasicSR utilities alone are not enough to instantiate CodeFormer.")
|
138 |
+
|
139 |
+
|
140 |
+
else:
|
141 |
+
raise ValueError(f"Configuration missing for model: {model_name}")
|
142 |
+
|
143 |
+
# Cache the loaded model
|
144 |
+
cached_model = model
|
145 |
+
cached_model_name = model_name
|
146 |
+
|
147 |
+
return model, model_type
|
148 |
+
|
149 |
+
except ImportError as ie:
|
150 |
+
print(f"Dependency missing for {model_name}: {ie}")
|
151 |
+
return None, f"Error: Missing dependency - {ie}. Please ensure model libraries are installed."
|
152 |
+
except Exception as e:
|
153 |
+
print(f"Error loading model {model_name}: {e}")
|
154 |
+
# Clear cache on error
|
155 |
+
cached_model = None
|
156 |
+
cached_model_name = None
|
157 |
+
return None, f"Error loading model: {str(e)}"
|
158 |
+
|
159 |
+
# Function to preprocess image (PIL RGB to OpenCV BGR numpy)
|
160 |
+
def preprocess_image(image: Image.Image) -> np.ndarray:
|
161 |
+
img = np.array(image)
|
162 |
+
# OpenCV uses BGR, PIL uses RGB. Need conversion.
|
163 |
+
img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
|
164 |
+
return img
|
165 |
+
|
166 |
+
# Function to postprocess image (OpenCV BGR numpy to PIL RGB)
|
167 |
+
def postprocess_image(img: np.ndarray) -> Image.Image:
|
168 |
+
# Ensure image is in the correct range and type before converting
|
169 |
+
if img.dtype != np.uint8:
|
170 |
+
img = np.clip(img, 0, 255).astype(np.uint8)
|
171 |
+
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
172 |
+
return Image.fromarray(img)
|
173 |
+
|
174 |
+
# Main processing function
|
175 |
+
def enhance_image(image: Image.Image, model_name: str):
|
176 |
+
if image is None:
|
177 |
+
return "Please upload an image.", None
|
178 |
+
|
179 |
+
status_message = f"Processing image with {model_name}..."
|
180 |
+
|
181 |
+
# Load the selected model and its type
|
182 |
+
model, model_info = load_model(model_name)
|
183 |
+
|
184 |
+
if model is None:
|
185 |
+
# model_info contains the error message if loading failed
|
186 |
+
return model_info, None
|
187 |
+
|
188 |
+
model_type = model_info # model_info is the type string ('upscale' or 'face')
|
189 |
+
|
190 |
+
try:
|
191 |
+
# Preprocess the image (PIL RGB -> OpenCV BGR)
|
192 |
+
img_np_bgr = preprocess_image(image)
|
193 |
+
|
194 |
+
# Process based on model type and specific model implementation
|
195 |
+
if model_type == "upscale":
|
196 |
+
print(f"Applying {model_name} upscaling...")
|
197 |
+
if model_name == "Real-ESRGAN-x4":
|
198 |
+
# RealESRGAN works with uint8 BGR numpy directly
|
199 |
+
output_np_bgr = model.predict(img_np_bgr)
|
200 |
+
elif model_name in ["SwinIR-4x", "EDSR-x4"]:
|
201 |
+
if not BASESR_AVAILABLE:
|
202 |
+
raise ImportError(f"basicsr is required for {model_name}")
|
203 |
+
# These models often work with float tensors (0-1 range)
|
204 |
+
# Using basicsr utils: HWC BGR uint8 -> CHW RGB float (0-1) -> send to device
|
205 |
+
img_tensor = img2tensor(img_np_bgr.astype(np.float32) / 255., bgr2rgb=True, float32=True).unsqueeze(0).to(device)
|
206 |
+
|
207 |
+
with torch.no_grad():
|
208 |
+
output_tensor = model(img_tensor)
|
209 |
+
|
210 |
+
# Using basicsr utils: CHW RGB float (0-1) -> HWC RGB uint8 -> Convert to BGR for postprocessing
|
211 |
+
output_np_rgb = tensor2img(output_tensor, rgb2bgr=False, min_max=(0, 1))
|
212 |
+
output_np_bgr = cv2.cvtColor(output_np_rgb, cv2.COLOR_RGB2BGR)
|
213 |
+
|
214 |
+
else:
|
215 |
+
raise ValueError(f"Unknown upscale model: {model_name}")
|
216 |
+
|
217 |
+
status_message = f"Image upscaled successfully with {model_name}!"
|
218 |
+
|
219 |
+
elif model_type == "face":
|
220 |
+
print(f"Applying {model_name} face enhancement...")
|
221 |
+
if model_name == "CodeFormer (Face Enhancement)":
|
222 |
+
if not (CODEFORMER_AVAILABLE or BASESR_AVAILABLE):
|
223 |
+
raise ImportError(f"CodeFormer or basicsr is required for {model_name}")
|
224 |
+
# CodeFormer's enhance method typically expects uint8 BGR numpy
|
225 |
+
# It might return multiple outputs, the first is usually the enhanced image
|
226 |
+
# Example: CodeFormer's inference script might return (restored_img, bboxes)
|
227 |
+
# We assume the image is the first element.
|
228 |
+
# Note: CodeFormer often needs additional setup/parameters for GFPGAN, etc.
|
229 |
+
# This is a simplified call.
|
230 |
+
# Ensure model is on correct device before call
|
231 |
+
if next(model.parameters()).device != device:
|
232 |
+
model.to(device)
|
233 |
+
|
234 |
+
# A minimal CodeFormer enhancement might look like this, but the actual API
|
235 |
+
# depends on the CodeFormer library version/structure you're using.
|
236 |
+
# The original CodeFormer repo's inference takes numpy BGR.
|
237 |
+
# This is a *placeholder* call assuming such a method exists and works like this:
|
238 |
+
output_np_bgr = model.enhance(img_np_bgr, w=0.5, adain=True)[0] # w and adain are common params
|
239 |
+
|
240 |
+
|
241 |
+
else:
|
242 |
+
raise ValueError(f"Unknown face enhancement model: {model_name}")
|
243 |
+
|
244 |
+
status_message = f"Face enhancement applied successfully with {model_name}!"
|
245 |
+
|
246 |
+
# Postprocess the output image (OpenCV BGR -> PIL RGB)
|
247 |
+
enhanced_image = postprocess_image(output_np_bgr)
|
248 |
+
|
249 |
+
return status_message, enhanced_image
|
250 |
+
|
251 |
+
except ImportError as ie:
|
252 |
+
return f"Error processing image: Missing dependency - {ie}", None
|
253 |
+
except Exception as e:
|
254 |
+
print(f"Error during processing: {e}")
|
255 |
+
import traceback
|
256 |
+
traceback.print_exc() # Print full traceback for debugging
|
257 |
+
return f"Error processing image: {str(e)}", None
|
258 |
+
|
259 |
+
# Gradio interface
|
260 |
+
with gr.Blocks(title="Image Upscale & Enhancement - By FebryEnsz") as demo:
|
261 |
+
gr.Markdown(
|
262 |
+
"""
|
263 |
+
# Image Upscale & Enhancement
|
264 |
+
**By FebryEnsz**
|
265 |
+
|
266 |
+
Upload an image and select a model to enhance it. Choose from multiple models for upscaling (to make it 'HD' or higher resolution) or face enhancement (to improve facial details and focus).
|
267 |
+
|
268 |
+
**Note:** This app requires specific Python libraries (`torch`, `basicsr`, `realesrgan`, `CodeFormer`) to be installed for all models to be available. If a model option is missing, its required library is not installed or found.
|
269 |
+
"""
|
270 |
+
)
|
271 |
+
|
272 |
+
with gr.Row():
|
273 |
+
with gr.Column():
|
274 |
+
image_input = gr.Image(label="Upload Image", type="pil")
|
275 |
+
|
276 |
+
# Filter available choices based on loaded configs
|
277 |
+
available_models = list(MODEL_CONFIGS.keys())
|
278 |
+
|
279 |
+
if not available_models:
|
280 |
+
model_choice = gr.Textbox(label="Select Model", value="No models available. Check dependencies.", interactive=False)
|
281 |
+
enhance_button = gr.Button("Enhance Image", interactive=False)
|
282 |
+
print("No models are available because dependencies are missing.")
|
283 |
+
else:
|
284 |
+
model_choice = gr.Dropdown(
|
285 |
+
choices=available_models,
|
286 |
+
label="Select Model",
|
287 |
+
value=available_models[0] # Default to the first available model
|
288 |
+
)
|
289 |
+
# Removed scale_slider as models are fixed scale (x4)
|
290 |
+
enhance_button = gr.Button("Enhance Image")
|
291 |
+
|
292 |
+
with gr.Column():
|
293 |
+
output_text = gr.Textbox(label="Status", max_lines=2)
|
294 |
+
output_image = gr.Image(label="Enhanced Image")
|
295 |
+
|
296 |
+
# Connect the button to the processing function
|
297 |
+
if available_models: # Only connect if models are available
|
298 |
+
enhance_button.click(
|
299 |
+
fn=enhance_image,
|
300 |
+
inputs=[image_input, model_choice],
|
301 |
+
outputs=[output_text, output_image]
|
302 |
+
)
|
303 |
+
|
304 |
+
# Launch the Gradio app
|
305 |
+
if __name__ == "__main__":
|
306 |
+
# Set torch backend for potentially better performance on some systems
|
307 |
+
if torch.backends.mps.is_available():
|
308 |
+
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1' # Optional: enable fallback for MPS
|
309 |
+
|
310 |
+
demo.launch()
|