b-aryan committed on
Commit
be320a7
·
verified ·
1 Parent(s): 800e587

First Commit

Browse files
Files changed (4) hide show
  1. app.py +268 -0
  2. generator_model.py +100 -0
  3. requirements.txt +8 -0
  4. run_on_patches_online.py +258 -0
app.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import time
4
+ from PIL import Image, UnidentifiedImageError
5
+ import os
6
+ import requests
7
+ import io
8
+ import uuid
9
+ from pathlib import Path
10
+ from huggingface_hub import hf_hub_download
11
+
12
+
13
+ # Assuming these are correctly imported from your processing script
14
+ from run_on_patches_online import (
15
+ load_model,
16
+ process_image_from_data,
17
+ Generator, # Make sure Generator is imported if needed by load_model
18
+ DEVICE,
19
+ #CHECKPOINT_GEN,
20
+ PATCH_KERNEL_SIZE,
21
+ PATCH_STRIDE
22
+ )
23
+
24
# --- Global Variables & Model Loading ---
# Location of the generator checkpoint on the Hugging Face Hub.
HF_REPO_ID = "b-aryan/WM-rem-epoch-42"
HF_FILENAME = "gen_epoch_42.pth.tar"
CHECKPOINT_GEN = HF_FILENAME

# Populated below; MODEL stays None (and MODEL_LOAD_ERROR is set) on failure so
# the UI can still start and report the problem.
MODEL = None
MODEL_LOAD_ERROR = None
DOWNLOADED_CHECKPOINT_PATH = None

print(f"Attempting to download/load model '{HF_FILENAME}' from repo '{HF_REPO_ID}' onto device '{DEVICE}'...")

try:
    # Step 1: fetch the checkpoint (uses the default Hugging Face cache).
    print(f"Downloading checkpoint '{HF_FILENAME}' from '{HF_REPO_ID}'...")
    DOWNLOADED_CHECKPOINT_PATH = hf_hub_download(
        repo_id=HF_REPO_ID,
        filename=HF_FILENAME,
    )
    print(f"Checkpoint downloaded successfully to: {DOWNLOADED_CHECKPOINT_PATH}")

    # Step 2: sanity-check the path, then build the model from it.
    checkpoint_missing = not Path(DOWNLOADED_CHECKPOINT_PATH).exists()
    if checkpoint_missing:
        raise FileNotFoundError(f"Downloaded checkpoint file not found at: {DOWNLOADED_CHECKPOINT_PATH}")

    MODEL = load_model(DOWNLOADED_CHECKPOINT_PATH, DEVICE)
    print("Model loaded successfully for Gradio app.")
except Exception as load_exc:
    import traceback

    MODEL_LOAD_ERROR = f"Failed to download or load model '{HF_FILENAME}' from '{HF_REPO_ID}'. Error: {load_exc}"
    print(f"Error: {MODEL_LOAD_ERROR}")
    traceback.print_exc()

# Directory that receives processed images offered for download.
TEMP_DIR = Path("temp")
TEMP_DIR.mkdir(exist_ok=True)
60
+
61
+ # --- Helper Function: Download Image (Simplified from run_on_patches_online) ---
62
def download_image_for_gradio(url: str, timeout: int = 20) -> Image.Image | None:
    """Download an image from ``url`` and return it as an RGB PIL image.

    Args:
        url: HTTP(S) address of the image.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        The downloaded image converted to RGB.

    Raises:
        gr.Error: For an invalid URL, a non-image content type, a payload
            larger than 20 MB, a network failure, or undecodable image data.
    """
    max_bytes = 20 * 1024 * 1024  # Hard cap on the payload to prevent abuse.

    print(f"Attempting to download image from: {url}")
    if not url or not url.startswith(('http://', 'https://')):
        raise gr.Error("Please enter a valid HTTP or HTTPS image URL.")

    try:
        headers = {'User-Agent': 'Gradio-Image-Processor/1.1'}
        # The context manager releases the connection even on early exits.
        with requests.get(url, stream=True, timeout=timeout, headers=headers) as response:
            response.raise_for_status()

            content_type = response.headers.get('Content-Type', '').lower()
            if not content_type.startswith('image/'):
                raise gr.Error(f"URL content type ({content_type}) is not recognized as an image.")

            # Fast rejection when the server announces the size up front.
            content_length = response.headers.get('Content-Length')
            if content_length and content_length.isdigit() and int(content_length) > max_bytes:
                raise gr.Error(f"Image file size ({int(content_length)/1024/1024:.1f} MB) exceeds the 20 MB limit.")

            # Content-Length may be missing or wrong (e.g. chunked transfer),
            # so the limit is also enforced while the body is read.
            buffer = io.BytesIO()
            for chunk in response.iter_content(chunk_size=64 * 1024):
                buffer.write(chunk)
                if buffer.tell() > max_bytes:
                    raise gr.Error("Image file size exceeds the 20 MB limit.")

        buffer.seek(0)
        pil_image = Image.open(buffer).convert('RGB')
        print(f"Image downloaded successfully ({pil_image.width}x{pil_image.height}).")
        return pil_image

    except gr.Error:
        # Already a user-facing message; do not let the generic handler
        # below replace it with "unexpected error".
        raise
    except requests.exceptions.Timeout:
        raise gr.Error(f"Request timed out after {timeout} seconds trying to download the image.")
    except requests.exceptions.RequestException as e:
        raise gr.Error(f"Error downloading image: {e}")
    except UnidentifiedImageError:
        raise gr.Error("Could not identify image file. The URL might not point to a valid image.")
    except Exception as e:
        print(f"An unexpected error occurred during download: {e}")  # Log for server admin
        raise gr.Error("An unexpected error occurred during image download.")
103
+
104
+
105
+ # --- Processing Function (Handles the ML part with progress) ---
106
def run_processing(input_pil_image: Image.Image, progress=gr.Progress(track_tqdm=True)):
    """Run the watermark-removal model on an already downloaded image.

    Args:
        input_pil_image: Image to process, or ``None`` when the input slot
            was cleared.
        progress: Gradio progress tracker (mirrors tqdm bars in the UI).

    Returns:
        ``(processed_image, saved_path)`` where ``saved_path`` is the string
        path of the JPEG written under ``TEMP_DIR``; ``(None, None)`` when
        there is no input image.

    Raises:
        gr.Error: If the model is not loaded or processing fails.
    """
    if MODEL is None:
        # Include the more specific error message if loading failed
        error_msg = f"Model is not loaded. Cannot process image. Load error: {MODEL_LOAD_ERROR}" if MODEL_LOAD_ERROR else "Model is not loaded. Cannot process image."
        raise gr.Error(error_msg)
    if input_pil_image is None:
        return None, None

    start_time = time.time()
    print("Starting image processing...")
    progress(0, desc="Preparing for processing...")

    try:
        output_pil_image = process_image_from_data(
            input_pil_image=input_pil_image,
            model=MODEL,
            device=DEVICE,
            kernel_size=PATCH_KERNEL_SIZE,
            stride=PATCH_STRIDE,
            use_tqdm=True,
        )
        if output_pil_image is None:
            raise gr.Error("Image processing failed internally. Check server logs.")

        # Save the processed image with a unique filename. The directory is
        # re-created in case it was removed while the app was running.
        filename = f"processed_{uuid.uuid4().hex}.jpg"
        save_path = TEMP_DIR / filename
        TEMP_DIR.mkdir(exist_ok=True)
        output_pil_image.save(save_path)
        print(f"Saved processed image to: {save_path}")

    except gr.Error:
        # Already user-facing; avoid wrapping it in a second message.
        raise
    except Exception as e:
        import traceback

        print(f"Exception during processing: {e}")
        traceback.print_exc()
        raise gr.Error(f"An error occurred during model processing: {e}") from e

    processing_time = time.time() - start_time
    print(f"Processing time: {processing_time:.2f} seconds")

    return output_pil_image, str(save_path)
148
+
149
+ # --- Wrapper Function for Button Click ---
150
def download_and_clear(url):
    """Fetch the image at ``url`` and blank the processed-image slot.

    Returns ``(image, None)``: the download for the original-image component
    and ``None`` to clear the processed-image component. Any failure surfaces
    as ``gr.Error`` so Gradio shows it to the user.
    """
    print("Download and clear triggered.") # Debug print
    try:
        downloaded = download_image_for_gradio(url)
    except gr.Error as download_err:
        # Known, user-facing failure: log it and let Gradio display it.
        print(f"Download failed: {download_err}") # Debug print
        raise
    except Exception as unexpected:
        # Anything else is converted so the user still sees a message.
        import traceback

        print(f"Unexpected error in download_and_clear: {unexpected}") # Debug print
        traceback.print_exc()
        raise gr.Error(f"An unexpected error occurred preparing the image: {unexpected}")
    return downloaded, None
169
+
170
+
171
# --- Gradio Blocks Interface Definition ---
print("Setting up Gradio Blocks interface...")

with gr.Blocks() as demo:
    # Path of the most recently saved processed image (set by run_processing).
    saved_image_path = gr.State()
    gr.Markdown("""
    # Watermark Remover

    **Disclaimer: This project was made to showcase my Deep Learning skills with no intention to cause harm to any business or infringe on any IP and it will be decommissioned as soon as I get a decent job offer.
    If you still have any issue, please reach out to me at [[email protected]](mailto:[email protected]).**
    """)

    gr.Markdown("""
    This is a demo of a DL model which takes in URL of an image with watermark and gives out the image with its watermark removed.
    The image is broken into overlapping patches of 256x256 pixels with a stride of 64 before feeding them to the model,
    the model infers on each patch separately and then they are stitched together to form the whole image again to give the final output.
    """)

    if MODEL_LOAD_ERROR:
        gr.Markdown(f"<font color='red'>**Model Loading Error:** {MODEL_LOAD_ERROR}</font>")

    with gr.Row():
        with gr.Column(scale=1):
            image_url = gr.Textbox(
                label="Image URL",
                placeholder="Paste image URL (https://www.shutterstock.com/shutterstock/photos/...)",
                info="Enter the direct link to a publicly accessible image file (jpg, png, etc.)."
            )
            submit_button = gr.Button("Process Image", variant="primary")
            gr.Markdown("---")  # Separator
            gr.Examples(
                examples=[
                    ["https://www.shutterstock.com/shutterstock/photos/2429728793/display_1500/stock-photo-monkey-funny-open-mouth-and-showing-teeth-crazy-nature-angry-short-hair-brown-grand-bassin-2429728793.jpg"],
                    ["https://www.shutterstock.com/shutterstock/photos/1905929728/display_1500/stock-photo-skeptic-surprised-cat-thinking-dont-know-what-to-do-big-eyes-closeup-tabby-cat-look-side-dont-1905929728.jpg"],
                    ["https://www.shutterstock.com/shutterstock/photos/2501926843/display_1500/stock-photo-brunette-woman-laying-on-couch-cuddling-light-brown-dog-and-brown-tabby-cat-happy-2501926843.jpg"]
                ],
                inputs=image_url  # Link examples to the URL textbox
            )

        with gr.Column(scale=2):
            with gr.Row():
                original_image = gr.Image(label="Original Input Image", type="pil", interactive=False)
                processed_image = gr.Image(label="Processed Output Image", type="pil", interactive=False)
            download_button = gr.DownloadButton("Download Processed Image", visible=True, variant="secondary")
            gr.Markdown("---")  # Separator
            gr.Markdown("""
            **A bit about the model:** In this project, I have trained a GAN network,
            with the Generator being inspired from Pix2Pix and Pix2PixHD architectures and the Discriminator is very similar to PatchGAN in Pix2Pix and CycleGAN.
            For the loss, I have also added Perceptual Loss using VGG like in Pix2PixHD and SRGAN papers apart from the L1 and BCE loss.
            """)

            gr.Markdown("""If you liked this project, you can find my CV [here]() or reach me out at [[email protected]](mailto:[email protected]).""")

    # --- Event Handling Logic ---

    # 1. Button click: download the URL into 'original_image' and clear
    #    'processed_image' (the wrapper returns None for it).
    submit_button.click(
        fn=download_and_clear,
        inputs=image_url,
        outputs=[original_image, processed_image]
    )

    # 2. 'original_image' changed (i.e. after a successful download): run the
    #    model and publish both the result image and the saved file path.
    original_image.change(
        fn=run_processing,
        inputs=original_image,
        outputs=[processed_image, saved_image_path]
    )

    def _sync_download_button(img, path):
        """Attach the saved file to the download button once a result exists.

        The file must be the button's value *before* the user clicks it;
        assigning it inside the button's own click handler is too late for
        that click. When the processed image is cleared, the stale file is
        detached and the button returns to its secondary look.
        """
        ready = img is not None and bool(path)
        return gr.update(
            value=path if ready else None,
            variant="primary" if ready else "secondary",
        )

    # 3. 'processed_image' changed: keep the download button in step with it.
    processed_image.change(
        fn=_sync_download_button,
        inputs=[processed_image, saved_image_path],
        outputs=download_button
    )


# --- Launch the Application ---
if __name__ == "__main__":
    print("Launching Gradio Blocks interface...")
    demo.launch(share=False, server_name="0.0.0.0", show_api=False)
268
+
generator_model.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class ConvBlock(nn.Module):
    """Bias-free, reflect-padded convolution followed by batch norm and an optional ReLU.

    Extra keyword arguments (``kernel_size``, ``stride``, ``padding``, ...) are
    forwarded to ``nn.Conv2d``.
    """

    def __init__(self, in_channels, out_channels, use_act=True, **kwargs):
        super().__init__()
        # Attribute names are part of the checkpoint's state_dict keys.
        self.cnn = nn.Conv2d(
            in_channels,
            out_channels,
            **kwargs,
            bias=False,
            padding_mode="reflect",
        )
        self.bn = nn.BatchNorm2d(out_channels)
        if use_act:
            self.act = nn.ReLU(inplace=True)
        else:
            self.act = nn.Identity()

    def forward(self, x):
        features = self.cnn(x)
        normalized = self.bn(features)
        return self.act(normalized)
14
+
15
class ResidualBlock(nn.Module):
    """Two 3x3 ConvBlocks with a skip connection and stochastic depth in training."""

    def __init__(self, in_channels):
        super().__init__()
        self.survival_prob = 0.8
        conv_args = dict(kernel_size=3, stride=1, padding=1)
        self.block1 = ConvBlock(in_channels, in_channels, **conv_args)
        # NOTE(review): the second block keeps its activation; left unchanged
        # here so the forward pass is identical.
        self.block2 = ConvBlock(in_channels, in_channels, use_act=True, **conv_args)

    def stochastic_depth(self, x):
        """Randomly drop the branch per sample while training; identity in eval mode."""
        if self.training:
            keep_mask = torch.rand(x.shape[0], 1, 1, 1, device=x.device) < self.survival_prob
            # Rescale the surviving samples by 1 / survival_prob.
            return torch.div(x, self.survival_prob) * keep_mask
        return x

    def forward(self, x):
        branch = self.block2(self.block1(x))
        return self.stochastic_depth(branch) + x
45
+
46
class Block(nn.Module):
    """3x3 reflect-padded convolution, batch norm, then ReLU or LeakyReLU(0.2).

    ``act="relu"`` selects ReLU; any other value selects LeakyReLU.
    """

    def __init__(self, in_channels, out_channels, stride=2, act="relu"):
        super().__init__()
        if act == "relu":
            activation = nn.ReLU(inplace=True)
        else:
            activation = nn.LeakyReLU(0.2, inplace=True)
        convolution = nn.Conv2d(
            in_channels,
            out_channels,
            3,
            stride,
            1,
            bias=False,
            padding_mode="reflect",
        )
        # Sequential positions 0/1/2 are part of the checkpoint's key names.
        self.conv = nn.Sequential(
            convolution,
            nn.BatchNorm2d(out_channels),
            activation,
        )

    def forward(self, x):
        return self.conv(x)
57
+
58
class Generator(nn.Module):
    """Encoder / residual-bottleneck / decoder generator with skip connections.

    Four stride-2 ``Block`` stages downsample the input, ``num_residuals``
    ``ResidualBlock`` layers run at the bottleneck, and four nearest-neighbour
    upsampling stages concatenate the matching encoder activations before a
    Tanh-terminated output head.
    """

    def __init__(self, in_channels=3, features=64, num_residuals=9):
        super().__init__()
        f = features

        # Stem: 7x7 convolution at full resolution.
        self.initial_down = nn.Sequential(
            nn.Conv2d(in_channels, f, 7, 1, 3, bias=True, padding_mode="reflect"),
            nn.ReLU(inplace=True),
        )

        # Encoder: each stage halves the resolution and doubles the channels.
        self.down1 = Block(f, f * 2, act="relu")
        self.down2 = Block(f * 2, f * 4, act="relu")
        self.down3 = Block(f * 4, f * 8, act="relu")
        self.down4 = Block(f * 8, f * 16, act="relu")

        # Bottleneck.
        bottleneck = [ResidualBlock(f * 16) for _ in range(num_residuals)]
        self.residuals = nn.Sequential(*bottleneck)

        # Decoder: inputs after up1 are doubled by the skip concatenation.
        self.up1 = Block(f * 16, f * 8, stride=1, act="relu")
        self.up2 = Block(f * 8 * 2, f * 4, stride=1, act="relu")
        self.up3 = Block(f * 4 * 2, f * 2, stride=1, act="relu")
        self.up4 = Block(f * 2 * 2, f, stride=1, act="relu")

        # Output head.
        self.final_conv = nn.Sequential(
            Block(f * 2, f, stride=1, act="relu"),
            Block(f, f, stride=1, act="relu"),
            nn.Conv2d(f, in_channels, 7, 1, 3, padding_mode="reflect"),
            nn.Tanh(),
        )

    @staticmethod
    def _upsample(tensor):
        """Double the spatial size with nearest-neighbour interpolation."""
        return F.interpolate(tensor, scale_factor=2, mode="nearest")

    def forward(self, x):
        # Encoder activations are kept for the skip connections.
        d1 = self.initial_down(x)
        d2 = self.down1(d1)
        d3 = self.down2(d2)
        d4 = self.down3(d3)
        d5 = self.down4(d4)

        bottleneck_out = self.residuals(d5) + d5

        u1 = self.up1(self._upsample(bottleneck_out))
        u2 = self.up2(self._upsample(torch.cat([u1, d4], dim=1)))
        u3 = self.up3(self._upsample(torch.cat([u2, d3], dim=1)))
        u4 = self.up4(self._upsample(torch.cat([u3, d2], dim=1)))

        return self.final_conv(torch.cat([u4, d1], dim=1))
93
+
94
def test():
    """Smoke test: push one random 256x256 RGB batch through the generator and print the output shape."""
    generator = Generator(in_channels=3, features=64)
    sample = torch.randn((1, 3, 256, 256))
    print(generator(sample).shape)
99
+
100
+ #test()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ gradio~=5.25.2
2
+ requests~=2.32.3
3
+ pillow~=11.1.0
4
+ numpy~=1.26.4
5
+ albumentations~=2.0.5
6
+ tqdm~=4.67.1
7
+ torch~=2.6.0
8
+ huggingface_hub~=0.30.2
run_on_patches_online.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from PIL import Image, UnidentifiedImageError
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import albumentations as A
7
+ from albumentations.pytorch import ToTensorV2
8
+ from collections import OrderedDict
9
+ from tqdm import tqdm
10
+ import requests
11
+ import io
12
+ from generator_model import Generator
13
+
14
# --- Constants ---
# Run on the GPU when one is available, otherwise on the CPU.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

CHECKPOINT_GEN = "gen_epoch_42.pth.tar"  # Local checkpoint name used by the __main__ test

# Sliding-window geometry: square patch side and step between patches.
PATCH_KERNEL_SIZE = 256
PATCH_STRIDE = 64

# File extensions accepted when naming locally saved test outputs.
SUPPORTED_EXTENSIONS = ('.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff')

# Maps uint8 pixels to [-1, 1] per channel and converts HWC arrays to CHW tensors.
_CHANNEL_MEAN = [0.5, 0.5, 0.5]
_CHANNEL_STD = [0.5, 0.5, 0.5]
test_transform = A.Compose(
    [
        A.Normalize(mean=_CHANNEL_MEAN, std=_CHANNEL_STD, max_pixel_value=255.0),
        ToTensorV2(),
    ]
)
27
+
28
def load_model(checkpoint_path: str, device: str) -> Generator:
    """Build a ``Generator`` and load its weights from ``checkpoint_path``.

    Args:
        checkpoint_path: Path to a checkpoint whose ``"state_dict"`` entry
            holds the generator weights.
        device: Torch device string the model and weights are mapped to.

    Returns:
        The generator in evaluation mode.

    Raises:
        KeyError: If the checkpoint has no ``"state_dict"`` entry.
        RuntimeError: If the weights do not match the architecture.
    """
    print(f"Loading model from: {checkpoint_path} onto device: {device}")
    model = Generator(in_channels=3, features=64).to(device)

    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints from a trusted source; switch to weights_only=True
    # if the checkpoint contents allow it.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)

    # Checkpoints saved from a DataParallel-wrapped model prefix every key
    # with "module."; strip only that leading prefix so a "module." that
    # appears elsewhere in a key is left alone.
    new_state_dict = OrderedDict(
        (key.removeprefix("module."), value)
        for key, value in checkpoint["state_dict"].items()
    )

    model.load_state_dict(new_state_dict)
    model.eval()  # Set model to evaluation mode
    print("Model loaded successfully.")
    return model
41
+
42
def calculate_padding(img_h: int, img_w: int, kernel_size: int, stride: int) -> tuple[int, int]:
    """Return the ``(pad_h, pad_w)`` needed so patches tile the image exactly.

    A dimension smaller than ``kernel_size`` is padded up to ``kernel_size``;
    otherwise it is padded so ``(dim - kernel_size)`` becomes a multiple of
    ``stride``.
    """

    def _pad_for(extent: int) -> int:
        if extent < kernel_size:
            return kernel_size - extent
        remainder = (extent - kernel_size) % stride
        return (stride - remainder) % stride

    return _pad_for(img_h), _pad_for(img_w)
46
+
47
def download_image(url: str, timeout: int = 15) -> Image.Image | None:
    """Download an image from ``url`` and return it as an RGB PIL image.

    Args:
        url: Address of the image.
        timeout: Timeout in seconds passed to ``requests.get``.

    Returns:
        The image converted to RGB, or ``None`` if the download failed, the
        response was not an image, or the data could not be decoded. Failures
        are reported on stdout rather than raised.
    """
    print(f"Attempting to download image from: {url}")
    try:
        headers = {'User-Agent': 'Gradio-Image-Processor/1.0'}  # Be a good net citizen
        # stream=True holds the connection until the body is read or the
        # response is closed; the context manager covers the early return.
        with requests.get(url, stream=True, timeout=timeout, headers=headers) as response:
            response.raise_for_status()  # Raise for bad status codes (4xx or 5xx)

            content_type = response.headers.get('Content-Type', '').lower()
            if not content_type.startswith('image/'):
                print(f"Error: URL content type ({content_type}) is not an image.")
                return None

            image_bytes = response.content

        pil_image = Image.open(io.BytesIO(image_bytes))
        pil_image = pil_image.convert('RGB')  # Ensure image is in RGB format
        print(f"Image downloaded successfully ({pil_image.width}x{pil_image.height}).")
        return pil_image

    except requests.exceptions.Timeout:
        print(f"Error: Request timed out after {timeout} seconds.")
        return None
    except requests.exceptions.RequestException as e:
        print(f"Error downloading image: {e}")
        return None
    except UnidentifiedImageError:
        print("Error: Could not identify image file. The URL might not point to a valid image.")
        return None
    except Exception as e:
        print(f"An unexpected error occurred during download: {e}")
        return None
78
+
79
def _pad_bottom_right(batch, pad_h, pad_w):
    """Pad a (1, C, H, W) tensor on its bottom and right edges by reflection.

    Reflect padding requires the pad to be smaller than the padded dimension,
    so large pads are applied in several passes. A 1-pixel dimension has
    nothing to reflect and is padded by replication instead.
    """
    while pad_h > 0 or pad_w > 0:
        _, _, cur_h, cur_w = batch.shape
        step_h = min(pad_h, cur_h - 1)
        step_w = min(pad_w, cur_w - 1)
        if (pad_h > 0 and step_h == 0) or (pad_w > 0 and step_w == 0):
            return F.pad(batch, (0, pad_w, 0, pad_h), mode='replicate')
        batch = F.pad(batch, (0, step_w, 0, step_h), mode='reflect')
        pad_h -= step_h
        pad_w -= step_w
    return batch


def process_image_from_data(
    input_pil_image: Image.Image,
    model: Generator,
    device: str,
    kernel_size: int,
    stride: int,
    use_tqdm: bool = True
) -> Image.Image | None:
    """Run the model over overlapping patches of an image and blend the results.

    The image is normalized, padded so patches tile it exactly, split into
    ``kernel_size`` x ``kernel_size`` patches taken every ``stride`` pixels,
    passed through ``model`` one patch at a time, and recombined as a
    Hann-window weighted average before cropping back to the input size.

    Args:
        input_pil_image: Image to process (converted to RGB if necessary).
        model: Generator called on ``(1, C, kernel_size, kernel_size)`` batches.
        device: Torch device string used for inference and reconstruction.
        kernel_size: Side length of each square patch.
        stride: Step between neighbouring patches.
        use_tqdm: Whether to show tqdm progress bars.

    Returns:
        The processed image, or ``None`` if any step raised (the traceback is
        printed).
    """
    print(f"\nProcessing image data...")
    try:
        # The normalization and the model both expect three channels.
        if input_pil_image.mode != 'RGB':
            input_pil_image = input_pil_image.convert('RGB')

        image_np = np.array(input_pil_image)  # Convert PIL Image to NumPy array
        H, W, _ = image_np.shape
        print(f"  Input dimensions: {W}x{H}")

        # Apply transformations
        transformed = test_transform(image=image_np)
        input_tensor = transformed['image'].to(device)  # Shape: (C, H, W)
        C = input_tensor.shape[0]

        # Calculate and apply padding (bottom/right only, so the final crop
        # is a simple [:H, :W] slice).
        pad_h, pad_w = calculate_padding(H, W, kernel_size, stride)
        print(f"  Calculated padding (H, W): ({pad_h}, {pad_w})")
        padded_tensor = _pad_bottom_right(input_tensor.unsqueeze(0), pad_h, pad_w).squeeze(0)
        _, H_pad, W_pad = padded_tensor.shape
        print(f"  Padded dimensions: {W_pad}x{H_pad}")

        # Extract patches: (C, nH, nW, k, k)
        patches = padded_tensor.unfold(1, kernel_size, stride).unfold(2, kernel_size, stride)
        num_patches_h = patches.shape[1]
        num_patches_w = patches.shape[2]
        num_patches_total = num_patches_h * num_patches_w
        print(f"  Extracted {num_patches_total} patches ({num_patches_h} H x {num_patches_w} W)")

        # Flatten the patch grid, then move it first: (nH * nW, C, k, k)
        patches = patches.contiguous().view(C, -1, kernel_size, kernel_size)
        patches = patches.permute(1, 0, 2, 3).contiguous()

        output_patches = []
        patch_iterator = tqdm(patches, total=num_patches_total, desc="  Inferring patches", unit="patch", leave=False, disable=not use_tqdm)

        # --- Inference Loop ---
        with torch.no_grad():
            for patch in patch_iterator:
                output_patch = model(patch.unsqueeze(0)).squeeze(0)
                # Move to CPU immediately to save GPU memory during the loop
                output_patches.append(output_patch.cpu())

        # Prefer reconstructing on the device; fall back to CPU-resident
        # patches (moved one at a time below) if they do not all fit.
        try:
            output_patches = torch.stack(output_patches).to(device)
            print(f"  Output patches moved to {device} for reconstruction.")
        except Exception as e:  # Catch potential OOM on device
            print(f"  Warning: Could not move all output patches to {device} ({e}). Reconstruction might be slower on CPU.")
            output_patches = torch.stack(output_patches)  # Keep on CPU

        # --- Reconstruction ---
        # 2D Hann window for blending. A non-periodic Hann window is exactly
        # zero at both ends, which would give the outermost rows/columns of
        # the image (covered by a single patch edge) zero total weight and a
        # flat grey result, so the window is floored at a small value.
        min_blend_weight = 1e-3
        window_1d = torch.hann_window(kernel_size, periodic=False, device=device)
        window_1d = window_1d.clamp(min=min_blend_weight)
        window_2d = torch.outer(window_1d, window_1d).unsqueeze(0)  # Add channel dim

        # Accumulators for the weighted sum and the total weight per pixel
        output_tensor = torch.zeros((C, H_pad, W_pad), device=device, dtype=output_patches.dtype)
        weight_tensor = torch.zeros((C, H_pad, W_pad), device=device, dtype=window_2d.dtype)

        reconstruct_iterator = tqdm(total=num_patches_total, desc="  Reconstructing", unit="patch", leave=False, disable=not use_tqdm)

        patch_idx = 0
        for i in range(num_patches_h):
            h_start = i * stride
            h_end = h_start + kernel_size
            for j in range(num_patches_w):
                w_start = j * stride
                w_end = w_start + kernel_size

                # Patches are stored row-major over the (nH, nW) grid
                current_patch = output_patches[patch_idx].to(device)

                output_tensor[:, h_start:h_end, w_start:w_end] += current_patch * window_2d
                weight_tensor[:, h_start:h_end, w_start:w_end] += window_2d

                patch_idx += 1
                reconstruct_iterator.update(1)

        reconstruct_iterator.close()

        # Weighted average - clamp weights to avoid division by zero
        output_averaged = output_tensor / weight_tensor.clamp(min=1e-6)

        # Crop to original dimensions
        output_cropped = output_averaged[:, :H, :W]
        print(f"  Final output dimensions: {output_cropped.shape[2]}x{output_cropped.shape[1]}")

        # --- Convert to Output Format ---
        # C, H, W -> H, W, C on the CPU as a NumPy array
        output_numpy = output_cropped.permute(1, 2, 0).cpu().numpy()

        # Undo the mean=0.5 / std=0.5 normalization and rescale to [0, 255]
        output_numpy = (output_numpy * 0.5 + 0.5) * 255.0
        output_numpy = output_numpy.clip(0, 255).astype(np.uint8)

        output_image = Image.fromarray(output_numpy)

        print("  Image processing complete.")
        return output_image

    except Exception as e:
        print(f"Error during image processing: {e}")
        import traceback
        traceback.print_exc()  # Print detailed traceback for debugging
        return None
205
+
206
if __name__ == "__main__":
    # Manual smoke test: load the local checkpoint, download one image,
    # process it, and save the result under test_outputs/.
    print("--- Testing Phase 1 Refactoring ---")
    print(f"Using device: {DEVICE}")
    print(f"Using patch kernel size: {PATCH_KERNEL_SIZE}")
    print(f"Using patch stride: {PATCH_STRIDE}")
    print(f"Using model checkpoint: {CHECKPOINT_GEN}")

    # 1. Load the model (as it would be done globally in Gradio app)
    try:
        model = load_model(CHECKPOINT_GEN, DEVICE)
    except Exception as e:
        print(f"Failed to load model. Exiting test. Error: {e}")
        # Non-zero status so callers can detect the failure; SystemExit
        # also works when the `site` module's exit() helper is unavailable.
        raise SystemExit(1)

    # 2. Test URL download
    test_url = "https://www.shutterstock.com/shutterstock/photos/2501926843/display_1500/stock-photo-brunette-woman-laying-on-couch-cuddling-light-brown-dog-and-brown-tabby-cat-happy-2501926843.jpg"
    input_pil = download_image(test_url)

    if input_pil:
        print(f"\nDownloaded image type: {type(input_pil)}, size: {input_pil.size}")

        # 3. Test processing the downloaded image
        output_pil = process_image_from_data(
            input_pil_image=input_pil,
            model=model,
            device=DEVICE,
            kernel_size=PATCH_KERNEL_SIZE,
            stride=PATCH_STRIDE,
            use_tqdm=True  # Show progress bars during test
        )

        if output_pil:
            print(f"\nProcessed image type: {type(output_pil)}, size: {output_pil.size}")
            # Save the output locally for verification during testing
            try:
                os.makedirs("test_outputs", exist_ok=True)
                # Basic filename extraction: last path segment without query string
                output_filename = "test_output_" + os.path.basename(test_url).split('?')[0]
                if not output_filename.lower().endswith(SUPPORTED_EXTENSIONS):
                    output_filename += ".png"  # Ensure it has an extension
                save_path = os.path.join("test_outputs", output_filename)
                output_pil.save(save_path)
                print(f"Saved test output to: {save_path}")
            except Exception as e:
                print(f"Error saving test output: {e}")
        else:
            print("\nImage processing failed.")
    else:
        print("\nImage download failed.")

    print("\n--- Phase 1 Testing Complete ---")