First Commit
- app.py +268 -0
- generator_model.py +100 -0
- requirements.txt +8 -0
- run_on_patches_online.py +258 -0
app.py
ADDED
@@ -0,0 +1,268 @@
import gradio as gr
import torch
import time
from PIL import Image, UnidentifiedImageError
import os
import requests
import io
import uuid
from pathlib import Path
from huggingface_hub import hf_hub_download


# Assuming these are correctly imported from your processing script
from run_on_patches_online import (
    load_model,
    process_image_from_data,
    Generator,  # Make sure Generator is imported if needed by load_model
    DEVICE,
    # CHECKPOINT_GEN,
    PATCH_KERNEL_SIZE,
    PATCH_STRIDE
)

# --- Global Variables & Model Loading ---
HF_REPO_ID = "b-aryan/WM-rem-epoch-42"
HF_FILENAME = "gen_epoch_42.pth.tar"
CHECKPOINT_GEN = HF_FILENAME
MODEL = None
MODEL_LOAD_ERROR = None
DOWNLOADED_CHECKPOINT_PATH = None  # Store the path after download

print(f"Attempting to download/load model '{HF_FILENAME}' from repo '{HF_REPO_ID}' onto device '{DEVICE}'...")

try:
    # 1. Download the model checkpoint from Hugging Face Hub
    print(f"Downloading checkpoint '{HF_FILENAME}' from '{HF_REPO_ID}'...")
    DOWNLOADED_CHECKPOINT_PATH = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=HF_FILENAME
        # cache_dir can be specified if needed; otherwise the default HF cache is used
    )
    print(f"Checkpoint downloaded successfully to: {DOWNLOADED_CHECKPOINT_PATH}")

    # 2. Load the model using the downloaded path
    if not os.path.exists(DOWNLOADED_CHECKPOINT_PATH):
        # This should not happen if hf_hub_download succeeded
        raise FileNotFoundError(f"Downloaded checkpoint file not found at: {DOWNLOADED_CHECKPOINT_PATH}")

    MODEL = load_model(DOWNLOADED_CHECKPOINT_PATH, DEVICE)
    print("Model loaded successfully for Gradio app.")

except Exception as e:
    MODEL_LOAD_ERROR = f"Failed to download or load model '{HF_FILENAME}' from '{HF_REPO_ID}'. Error: {e}"
    print(f"Error: {MODEL_LOAD_ERROR}")
    import traceback
    traceback.print_exc()

TEMP_DIR = Path("temp")
TEMP_DIR.mkdir(exist_ok=True)

# --- Helper Function: Download Image (Simplified from run_on_patches_online) ---
def download_image_for_gradio(url: str, timeout: int = 20) -> Image.Image | None:
    """Downloads an image from a URL for Gradio; returns a PIL Image or raises gr.Error."""
    print(f"Attempting to download image from: {url}")
    if not url or not url.startswith(('http://', 'https://')):
        raise gr.Error("Please enter a valid HTTP or HTTPS image URL.")

    try:
        headers = {'User-Agent': 'Gradio-Image-Processor/1.1'}
        response = requests.get(url, stream=True, timeout=timeout, headers=headers)
        response.raise_for_status()

        content_type = response.headers.get('Content-Type', '').lower()
        if not content_type.startswith('image/'):
            raise gr.Error(f"URL content type ({content_type}) is not recognized as an image.")

        # Limit download size (e.g., 20 MB) to prevent abuse
        content_length = response.headers.get('Content-Length')
        if content_length and int(content_length) > 20 * 1024 * 1024:
            raise gr.Error(f"Image file size ({int(content_length)/1024/1024:.1f} MB) exceeds the 20 MB limit.")

        image_bytes = response.content
        pil_image = Image.open(io.BytesIO(image_bytes))
        pil_image = pil_image.convert('RGB')
        print(f"Image downloaded successfully ({pil_image.width}x{pil_image.height}).")

        # Optional: add image dimension limits if needed
        # max_dim = 2048
        # if pil_image.width > max_dim or pil_image.height > max_dim:
        #     raise gr.Error(f"Image dimensions ({pil_image.width}x{pil_image.height}) exceed the maximum allowed ({max_dim}x{max_dim}).")

        return pil_image

    except requests.exceptions.Timeout:
        raise gr.Error(f"Request timed out after {timeout} seconds trying to download the image.")
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Error downloading image: {e}")
    except UnidentifiedImageError:
        raise gr.Error("Could not identify image file. The URL might not point to a valid image.")
    except Exception as e:
        print(f"An unexpected error occurred during download: {e}")  # Log for server admin
        raise gr.Error("An unexpected error occurred during image download.")


# --- Processing Function (Handles the ML part with progress) ---
def run_processing(input_pil_image: Image.Image, progress=gr.Progress(track_tqdm=True)):
    """Processes the already downloaded PIL image and returns the result."""
    if MODEL is None:
        # Include the more specific error message if loading failed
        error_msg = f"Model is not loaded. Cannot process image. Load error: {MODEL_LOAD_ERROR}" if MODEL_LOAD_ERROR else "Model is not loaded. Cannot process image."
        raise gr.Error(error_msg)
    if input_pil_image is None:
        return None, None

    start_time = time.time()
    print("Starting image processing...")
    progress(0, desc="Preparing for processing...")

    try:
        output_pil_image = process_image_from_data(
            input_pil_image=input_pil_image,
            model=MODEL,
            device=DEVICE,
            kernel_size=PATCH_KERNEL_SIZE,
            stride=PATCH_STRIDE,
            use_tqdm=True
        )
        if output_pil_image is None:
            raise gr.Error("Image processing failed internally. Check server logs.")

        # Save the processed image with a unique filename
        filename = f"processed_{uuid.uuid4().hex}.jpg"
        save_path = TEMP_DIR / filename
        output_pil_image.save(save_path)
        print(f"Saved processed image to: {save_path}")

    except Exception as e:
        print(f"Exception during processing: {e}")
        import traceback
        traceback.print_exc()
        raise gr.Error(f"An error occurred during model processing: {e}")

    end_time = time.time()
    processing_time = end_time - start_time
    print(f"Processing time: {processing_time:.2f} seconds")

    return output_pil_image, str(save_path)

# --- Wrapper Function for Button Click ---
def download_and_clear(url):
    """Calls download; returns (image, None) on success, raises gr.Error on failure."""
    print("Download and clear triggered.")  # Debug print
    try:
        img = download_image_for_gradio(url)
        # Return the downloaded image for the original slot, and None for the processed slot
        return img, None
    except gr.Error as e:
        # If download fails, re-raise the Gradio error;
        # Gradio will handle clearing outputs and displaying the error message
        print(f"Download failed: {e}")  # Debug print
        raise
    except Exception as e:
        # Catch other unexpected errors during the wrapper logic
        print(f"Unexpected error in download_and_clear: {e}")  # Debug print
        import traceback
        traceback.print_exc()
        # Raise a Gradio error so the user sees something meaningful
        raise gr.Error(f"An unexpected error occurred preparing the image: {e}")


# --- Gradio Blocks Interface Definition ---
print("Setting up Gradio Blocks interface...")

with gr.Blocks() as demo:
    saved_image_path = gr.State()
    gr.Markdown("""
    # Watermark Remover

    **Disclaimer: this project was made to showcase my deep-learning skills, with no intention to harm any business or infringe on any IP, and it will be decommissioned as soon as I get a decent job offer.
    If you still have any issues, please reach out to me at [[email protected]](mailto:[email protected]).**
    """)

    gr.Markdown("""
    This is a demo of a DL model that takes the URL of a watermarked image and outputs the image with the watermark removed.
    The image is broken into overlapping 256x256-pixel patches with a stride of 64 before being fed to the model;
    the model infers on each patch separately, and the patches are then stitched back together to form the final output image.
    """)

    if MODEL_LOAD_ERROR:
        gr.Markdown(f"<font color='red'>**Model Loading Error:** {MODEL_LOAD_ERROR}</font>")

    with gr.Row():
        with gr.Column(scale=1):
            image_url = gr.Textbox(
                label="Image URL",
                placeholder="Paste image URL (https://www.shutterstock.com/shutterstock/photos/...)",
                info="Enter the direct link to a publicly accessible image file (jpg, png, etc.)."
            )
            submit_button = gr.Button("Process Image", variant="primary")
            gr.Markdown("---")  # Separator
            gr.Examples(
                examples=[
                    ["https://www.shutterstock.com/shutterstock/photos/2429728793/display_1500/stock-photo-monkey-funny-open-mouth-and-showing-teeth-crazy-nature-angry-short-hair-brown-grand-bassin-2429728793.jpg"],
                    ["https://www.shutterstock.com/shutterstock/photos/1905929728/display_1500/stock-photo-skeptic-surprised-cat-thinking-dont-know-what-to-do-big-eyes-closeup-tabby-cat-look-side-dont-1905929728.jpg"],
                    ["https://www.shutterstock.com/shutterstock/photos/2501926843/display_1500/stock-photo-brunette-woman-laying-on-couch-cuddling-light-brown-dog-and-brown-tabby-cat-happy-2501926843.jpg"]
                ],
                inputs=image_url  # Link examples to the URL textbox
            )

        with gr.Column(scale=2):
            with gr.Row():
                original_image = gr.Image(label="Original Input Image", type="pil", interactive=False)
                processed_image = gr.Image(label="Processed Output Image", type="pil", interactive=False)
            download_button = gr.DownloadButton("Download Processed Image", visible=True, variant="secondary")
            gr.Markdown("---")  # Separator
            gr.Markdown("""
            **A bit about the model:** In this project, I trained a GAN whose
            Generator is inspired by the Pix2Pix and Pix2PixHD architectures, while the Discriminator is very similar to the PatchGAN used in Pix2Pix and CycleGAN.
            For the loss, in addition to the L1 and BCE losses, I added a VGG-based Perceptual Loss, as in the Pix2PixHD and SRGAN papers.
            """)

    gr.Markdown("""If you liked this project, you can find my CV [here]() or reach out to me at [[email protected]](mailto:[email protected]).""")


    # --- Event Handling Logic ---

    # 1. When the button is clicked:
    #    - Input: URL from the textbox
    #    - Action: call the 'download_and_clear' wrapper function
    #    - Output: update 'original_image' with the downloaded image; clear 'processed_image' by returning None for it.
    submit_button.click(
        fn=download_and_clear,  # Use the wrapper function
        inputs=image_url,
        outputs=[original_image, processed_image]  # Target both outputs
    )

    # 2. When 'original_image' *changes* (i.e., after a successful download):
    #    - Input: the PIL image data from 'original_image'
    #    - Action: call 'run_processing' (this function has the progress bar)
    #    - Output: update the 'processed_image' component.
    original_image.change(
        fn=run_processing,
        inputs=original_image,
        outputs=[processed_image, saved_image_path]
        # concurrency_limit=1  # Optional: prevent multiple simultaneous processing runs if needed
    )

    processed_image.change(
        fn=lambda img: gr.update(variant="primary" if img is not None else "secondary"),
        inputs=processed_image,
        outputs=download_button
    )

    download_button.click(
        fn=lambda path: path if path else None,
        inputs=saved_image_path,
        outputs=download_button
    )


# --- Launch the Application ---
if __name__ == "__main__":
    print("Launching Gradio Blocks interface...")
    # Enable queuing for better handling under load, especially with long-running processing
    # demo.queue()  # Consider adding this
    demo.launch(share=False, server_name="0.0.0.0", show_api=False)
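Review note on the launch block above: `demo.queue()` is left commented out, so long patch-inference runs are handled without a request queue. A minimal sketch of what enabling the queue could look like; the `default_concurrency_limit` value is an assumption, not part of this commit:

# Sketch (not part of the commit): enable Gradio's request queue so that
# long-running patch inference is serialized instead of blocking workers.
# default_concurrency_limit=1 is an assumed value that admits one
# processing job at a time; tune it to the Space's hardware.
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=1)
    demo.launch(share=False, server_name="0.0.0.0", show_api=False)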
generator_model.py
ADDED
@@ -0,0 +1,100 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, use_act=True, **kwargs):
        super().__init__()
        self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=False, padding_mode="reflect")
        self.bn = nn.BatchNorm2d(out_channels)
        self.act = nn.ReLU(inplace=True) if use_act else nn.Identity()

    def forward(self, x):
        return self.act(self.bn(self.cnn(x)))

class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.survival_prob = 0.8
        self.block1 = ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.block2 = ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            use_act=True,
        )

    def stochastic_depth(self, x):
        if not self.training:
            return x
        binary_tensor = torch.rand(x.shape[0], 1, 1, 1, device=x.device) < self.survival_prob
        return torch.div(x, self.survival_prob) * binary_tensor

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        return self.stochastic_depth(out) + x

class Block(nn.Module):
    def __init__(self, in_channels, out_channels, stride=2, act="relu"):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False, padding_mode="reflect"),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True) if act == "relu" else nn.LeakyReLU(0.2, inplace=True),
        )

    def forward(self, x):
        return self.conv(x)

class Generator(nn.Module):
    def __init__(self, in_channels=3, features=64, num_residuals=9):
        super().__init__()
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, features, 7, 1, 3, bias=True, padding_mode="reflect"),
            nn.ReLU(inplace=True),
        )
        self.down1 = Block(features, features*2, act="relu")
        self.down2 = Block(features*2, features*4, act="relu")
        self.down3 = Block(features*4, features*8, act="relu")
        self.down4 = Block(features*8, features*16, act="relu")
        self.residuals = nn.Sequential(*[ResidualBlock(features*16) for _ in range(num_residuals)])
        self.up1 = Block(features*16, features*8, stride=1, act="relu")
        self.up2 = Block(features*8*2, features*4, stride=1, act="relu")
        self.up3 = Block(features*4*2, features*2, stride=1, act="relu")
        self.up4 = Block(features*2*2, features, stride=1, act="relu")
        self.final_conv = nn.Sequential(
            Block(features*2, features, stride=1, act="relu"),
            Block(features, features, stride=1, act="relu"),
            nn.Conv2d(features, in_channels, 7, 1, 3, padding_mode="reflect"),
            nn.Tanh(),
        )

    def forward(self, x):
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)
        residuals = self.residuals(d5) + d5
        u1 = self.up1(F.interpolate(residuals, scale_factor=2, mode="nearest"))
        u2 = self.up2(F.interpolate(torch.cat([u1, d4], dim=1), scale_factor=2, mode="nearest"))
        u3 = self.up3(F.interpolate(torch.cat([u2, d3], dim=1), scale_factor=2, mode="nearest"))
        u4 = self.up4(F.interpolate(torch.cat([u3, d2], dim=1), scale_factor=2, mode="nearest"))
        return self.final_conv(torch.cat([u4, d1], dim=1))

def test():
    x = torch.randn((1, 3, 256, 256))
    model = Generator(in_channels=3, features=64)
    preds = model(x)
    print(preds.shape)

# test()
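Review note on the Generator above: the four stride-2 downsampling blocks are undone by four nearest-neighbour upsamplings, so the skip concatenations only line up when the input height and width are multiples of 16; the 256x256 patches used elsewhere in this commit satisfy that. A small shape check, assuming a CPU run with randomly initialized weights:

import torch
from generator_model import Generator

# Verify that the output shape matches the input shape for sizes divisible by 16.
model = Generator(in_channels=3, features=64).eval()
with torch.no_grad():
    for size in (64, 128, 256):  # all multiples of 16
        x = torch.randn(1, 3, size, size)
        y = model(x)
        assert y.shape == x.shape, (y.shape, x.shape)
        print(size, "->", tuple(y.shape))  # e.g. 256 -> (1, 3, 256, 256)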
requirements.txt
ADDED
@@ -0,0 +1,8 @@
gradio~=5.25.2
requests~=2.32.3
pillow~=11.1.0
numpy~=1.26.4
albumentations~=2.0.5
tqdm~=4.67.1
torch~=2.6.0
huggingface_hub~=0.30.2
run_on_patches_online.py
ADDED
@@ -0,0 +1,258 @@
import os
from PIL import Image, UnidentifiedImageError
import numpy as np
import torch
import torch.nn.functional as F
import albumentations as A
from albumentations.pytorch import ToTensorV2
from collections import OrderedDict
from tqdm import tqdm
import requests
import io
from generator_model import Generator

# --- Constants ---
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CHECKPOINT_GEN = "gen_epoch_42.pth.tar"  # Keep your checkpoint name
PATCH_KERNEL_SIZE = 256
PATCH_STRIDE = 64
# DEFAULT_INPUT_DIR = "test/inputs"    # No longer needed for Gradio URL input
# DEFAULT_OUTPUT_DIR = "test/outputs"  # Output handled by Gradio
SUPPORTED_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')  # Still useful for local testing if needed

test_transform = A.Compose([
    A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0),
    ToTensorV2()
])

def load_model(checkpoint_path: str, device: str) -> Generator:
    print(f"Loading model from: {checkpoint_path} onto device: {device}")
    model = Generator(in_channels=3, features=64).to(device)
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
    new_state_dict = OrderedDict()
    has_module_prefix = any(k.startswith("module.") for k in checkpoint["state_dict"])
    for k, v in checkpoint["state_dict"].items():
        name = k.replace("module.", "") if has_module_prefix else k
        new_state_dict[name] = v
    model.load_state_dict(new_state_dict)
    model.eval()  # Set model to evaluation mode
    print("Model loaded successfully.")
    return model

def calculate_padding(img_h: int, img_w: int, kernel_size: int, stride: int) -> tuple[int, int]:
    pad_h = kernel_size - img_h if img_h < kernel_size else (stride - (img_h - kernel_size) % stride) % stride
    pad_w = kernel_size - img_w if img_w < kernel_size else (stride - (img_w - kernel_size) % stride) % stride
    return pad_h, pad_w

def download_image(url: str, timeout: int = 15) -> Image.Image | None:
    """Downloads an image from a URL and returns it as a PIL Image object."""
    print(f"Attempting to download image from: {url}")
    try:
        headers = {'User-Agent': 'Gradio-Image-Processor/1.0'}  # Be a good net citizen
        response = requests.get(url, stream=True, timeout=timeout, headers=headers)
        response.raise_for_status()  # Raise an exception for bad status codes (4xx or 5xx)

        content_type = response.headers.get('Content-Type', '').lower()
        if not content_type.startswith('image/'):
            print(f"Error: URL content type ({content_type}) is not an image.")
            return None

        image_bytes = response.content
        pil_image = Image.open(io.BytesIO(image_bytes))
        pil_image = pil_image.convert('RGB')  # Ensure image is in RGB format
        print(f"Image downloaded successfully ({pil_image.width}x{pil_image.height}).")
        return pil_image

    except requests.exceptions.Timeout:
        print(f"Error: Request timed out after {timeout} seconds.")
        return None
    except requests.exceptions.RequestException as e:
        print(f"Error downloading image: {e}")
        return None
    except UnidentifiedImageError:
        print("Error: Could not identify image file. The URL might not point to a valid image.")
        return None
    except Exception as e:
        print(f"An unexpected error occurred during download: {e}")
        return None

def process_image_from_data(
    input_pil_image: Image.Image,
    model: Generator,
    device: str,
    kernel_size: int,
    stride: int,
    use_tqdm: bool = True  # Optional: control progress bar visibility
) -> Image.Image | None:
    """
    Processes an input PIL image using the patch-based method and returns the output PIL image.
    Returns None if an error occurs during processing.
    """
    print("\nProcessing image data...")
    try:
        image_np = np.array(input_pil_image)  # Convert PIL Image to NumPy array
        H, W, _ = image_np.shape
        print(f"  Input dimensions: {W}x{H}")

        # Apply transformations
        transformed = test_transform(image=image_np)
        input_tensor = transformed['image'].to(device)  # Shape: (C, H, W)
        C = input_tensor.shape[0]

        # Calculate and apply padding
        pad_h, pad_w = calculate_padding(H, W, kernel_size, stride)
        print(f"  Calculated padding (H, W): ({pad_h}, {pad_w})")
        padded_tensor = F.pad(input_tensor.unsqueeze(0), (0, pad_w, 0, pad_h), mode='reflect').squeeze(0)
        _, H_pad, W_pad = padded_tensor.shape
        print(f"  Padded dimensions: {W_pad}x{H_pad}")

        # Extract patches
        patches = padded_tensor.unfold(1, kernel_size, stride).unfold(2, kernel_size, stride)
        num_patches_h = patches.shape[1]
        num_patches_w = patches.shape[2]
        num_patches_total = num_patches_h * num_patches_w
        print(f"  Extracted {num_patches_total} patches ({num_patches_h} H x {num_patches_w} W)")

        patches = patches.contiguous().view(C, -1, kernel_size, kernel_size)
        # Permute to (num_patches_total, C, kernel_size, kernel_size)
        patches = patches.permute(1, 0, 2, 3).contiguous()

        output_patches = []
        # Set up tqdm iterator if enabled
        patch_iterator = tqdm(patches, total=num_patches_total, desc="  Inferring patches", unit="patch", leave=False, disable=not use_tqdm)

        # --- Inference Loop ---
        with torch.no_grad():
            for patch in patch_iterator:
                # Add batch dimension, run model, remove batch dimension
                output_patch = model(patch.unsqueeze(0)).squeeze(0)
                # Move to CPU immediately to save GPU memory during the inference loop
                output_patches.append(output_patch.cpu())

        # Stack output patches back together.
        # If GPU memory allows, move them back to the device for faster reconstruction; otherwise keep them on the CPU.
        try:
            output_patches = torch.stack(output_patches).to(device)
            print(f"  Output patches moved to {device} for reconstruction.")
        except Exception as e:  # Catch potential OOM on the device
            print(f"  Warning: Could not move all output patches to {device} ({e}). Reconstruction might be slower on CPU.")
            output_patches = torch.stack(output_patches)  # Keep on CPU

        # --- Reconstruction ---
        # Generate 2D Hann window for blending
        window_1d = torch.hann_window(kernel_size, periodic=False, device=device)  # periodic=False is common
        window_2d = torch.outer(window_1d, window_1d)
        window_2d = window_2d.unsqueeze(0).to(device)  # Add channel dim and ensure on device

        # Initialize output tensor and weight tensor (for weighted averaging)
        output_tensor = torch.zeros((C, H_pad, W_pad), device=device, dtype=output_patches.dtype)
        weight_tensor = torch.zeros((C, H_pad, W_pad), device=device, dtype=window_2d.dtype)

        patch_idx = 0
        reconstruct_iterator = tqdm(total=num_patches_total, desc="  Reconstructing", unit="patch", leave=False, disable=not use_tqdm)

        for i in range(num_patches_h):
            for j in range(num_patches_w):
                h_start = i * stride
                w_start = j * stride
                h_end = h_start + kernel_size
                w_end = w_start + kernel_size

                # Get current patch (ensure it's on the correct device)
                current_patch = output_patches[patch_idx].to(device)
                weighted_patch = current_patch * window_2d  # Apply window

                # Add weighted patch to output tensor
                output_tensor[:, h_start:h_end, w_start:w_end] += weighted_patch
                # Accumulate weights
                weight_tensor[:, h_start:h_end, w_start:w_end] += window_2d

                patch_idx += 1
                reconstruct_iterator.update(1)

        reconstruct_iterator.close()  # Close the inner tqdm bar

        # Perform weighted averaging; clamp weights to avoid division by zero
        output_averaged = output_tensor / weight_tensor.clamp(min=1e-6)

        # Crop to original dimensions
        output_cropped = output_averaged[:, :H, :W]
        print(f"  Final output dimensions: {output_cropped.shape[2]}x{output_cropped.shape[1]}")

        # --- Convert to Output Format ---
        # Permute C, H, W -> H, W, C; move to CPU; convert to NumPy
        output_numpy = output_cropped.permute(1, 2, 0).cpu().numpy()

        # Denormalize, assuming the input was normalized to [-1, 1]
        output_numpy = (output_numpy * 0.5 + 0.5) * 255.0

        # Clip values to [0, 255] and convert to uint8
        output_numpy = output_numpy.clip(0, 255).astype(np.uint8)

        # Convert NumPy array back to PIL Image
        output_image = Image.fromarray(output_numpy)

        print("  Image processing complete.")
        return output_image

    except Exception as e:
        print(f"Error during image processing: {e}")
        import traceback
        traceback.print_exc()  # Print detailed traceback for debugging
        return None

if __name__ == "__main__":
    print("--- Testing Phase 1 Refactoring ---")
    print(f"Using device: {DEVICE}")
    print(f"Using patch kernel size: {PATCH_KERNEL_SIZE}")
    print(f"Using patch stride: {PATCH_STRIDE}")
    print(f"Using model checkpoint: {CHECKPOINT_GEN}")

    # 1. Load the model (as it would be done globally in the Gradio app)
    try:
        model = load_model(CHECKPOINT_GEN, DEVICE)
    except Exception as e:
        print(f"Failed to load model. Exiting test. Error: {e}")
        exit()

    # 2. Test URL download
    # Replace with a valid image URL for testing
    # test_url = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg"
    test_url = "https://www.shutterstock.com/shutterstock/photos/2501926843/display_1500/stock-photo-brunette-woman-laying-on-couch-cuddling-light-brown-dog-and-brown-tabby-cat-happy-2501926843.jpg"
    input_pil = download_image(test_url)

    if input_pil:
        print(f"\nDownloaded image type: {type(input_pil)}, size: {input_pil.size}")

        # 3. Test processing the downloaded image
        output_pil = process_image_from_data(
            input_pil_image=input_pil,
            model=model,
            device=DEVICE,
            kernel_size=PATCH_KERNEL_SIZE,
            stride=PATCH_STRIDE,
            use_tqdm=True  # Show progress bars during test
        )

        if output_pil:
            print(f"\nProcessed image type: {type(output_pil)}, size: {output_pil.size}")
            # Save the output locally for verification during testing
            try:
                os.makedirs("test_outputs", exist_ok=True)
                output_filename = "test_output_" + os.path.basename(test_url).split('?')[0]  # Basic filename extraction
                if not output_filename.lower().endswith(SUPPORTED_EXTENSIONS):
                    output_filename += ".png"  # Ensure it has an extension
                save_path = os.path.join("test_outputs", output_filename)
                output_pil.save(save_path)
                print(f"Saved test output to: {save_path}")
            except Exception as e:
                print(f"Error saving test output: {e}")
        else:
            print("\nImage processing failed.")
    else:
        print("\nImage download failed.")

    print("\n--- Phase 1 Testing Complete ---")
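Review note on the padding and reconstruction logic above: `calculate_padding` pads so that `(dim + pad - kernel)` is an exact multiple of the stride, which lets `unfold` tile the padded image with no remainder, and the Hann-weighted overlap-add cancels exactly wherever the accumulated weight is non-negligible. A small self-check of both properties, assuming the module is importable; it feeds the extracted patches straight back as identity "model outputs", so the reconstruction should reproduce the input everywhere except the outermost pixels, where `periodic=False` gives the window (near-)zero weight:

import torch
from run_on_patches_online import calculate_padding, PATCH_KERNEL_SIZE, PATCH_STRIDE

k, s = PATCH_KERNEL_SIZE, PATCH_STRIDE

# 1) Padding: (dim + pad - kernel) % stride == 0, so unfold tiles exactly.
pad_h, pad_w = calculate_padding(300, 500, k, s)
print(pad_h, pad_w)  # 20 12 -> padded image is 320 x 512
assert (300 + pad_h - k) % s == 0 and (500 + pad_w - k) % s == 0

# 2) Blending: with identity patches, the same Hann weight appears in the
# numerator and denominator at every pixel, so the ratio recovers the input.
x = torch.rand(3, 320, 512)
patches = x.unfold(1, k, s).unfold(2, k, s)  # (C, nH, nW, k, k)
w1 = torch.hann_window(k, periodic=False)
w2 = torch.outer(w1, w1)
out = torch.zeros_like(x)
wsum = torch.zeros_like(x)
for i in range(patches.shape[1]):
    for j in range(patches.shape[2]):
        hs, ws = i * s, j * s
        out[:, hs:hs + k, ws:ws + k] += patches[:, i, j] * w2
        wsum[:, hs:hs + k, ws:ws + k] += w2
recon = out / wsum.clamp(min=1e-6)
mask = wsum > 1e-3  # exclude border pixels whose total window weight is ~0
print(torch.allclose(recon[mask], x[mask], atol=1e-4))  # True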