Bismay committed on
Commit
475e066
·
0 Parent(s):

Initial commit

.gitignore ADDED
@@ -0,0 +1,178 @@
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ env/
8
+ build/
9
+ develop-eggs/
10
+ dist/
11
+ downloads/
12
+ eggs/
13
+ .eggs/
14
+ lib/
15
+ lib64/
16
+ parts/
17
+ sdist/
18
+ var/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+ .env
23
+ .venv
24
+ venv/
25
+ ENV/
26
+
27
+ # IDE
28
+ .idea/
29
+ .vscode/
30
+ *.swp
31
+ *.swo
32
+ .DS_Store
33
+
34
+ # Binary files and assets
35
+ parser/u2net_cloth_seg/assets/*.png
36
+ upscaler/real_esrgan/assets/*.png
37
+ upscaler/real_esrgan/assets/*.jpg
38
+ upscaler/real_esrgan/inputs/*.png
39
+ upscaler/real_esrgan/inputs/video/*.mp4
40
+ upscaler/real_esrgan/tests/data/gt.lmdb/
41
+ upscaler/real_esrgan/tests/data/gt/*.png
42
+
43
+ # Models
44
+ models/
45
+ *.pth
46
+ *.ckpt
47
+ *.safetensors
48
+
49
+ # Logs
50
+ *.log
51
+ logs/
52
+
53
+ # Temporary files
54
+ *.tmp
55
+ *.temp
56
+ *.bak
57
+ *.backup
58
+
59
+ # System files
60
+ .DS_Store
61
+ Thumbs.db
62
+
63
+ # Distribution / packaging
64
+ .Python
65
+ build/
66
+ develop-eggs/
67
+ dist/
68
+ downloads/
69
+ eggs/
70
+ .eggs/
71
+ lib/
72
+ lib64/
73
+ parts/
74
+ sdist/
75
+ var/
76
+ wheels/
77
+ pip-wheel-metadata/
78
+ share/python-wheels/
79
+ *.egg-info/
80
+ .installed.cfg
81
+ *.egg
82
+ MANIFEST
83
+
84
+ # PyInstaller
85
+ # Usually these files are written by a python script from a template
86
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
87
+ *.manifest
88
+ *.spec
89
+
90
+ # Installer logs
91
+ pip-log.txt
92
+ pip-delete-this-directory.txt
93
+
94
+ # Unit test / coverage reports
95
+ htmlcov/
96
+ .tox/
97
+ .nox/
98
+ .coverage
99
+ .coverage.*
100
+ .cache
101
+ nosetests.xml
102
+ coverage.xml
103
+ *.cover
104
+ *.py,cover
105
+ .hypothesis/
106
+ .pytest_cache/
107
+
108
+ # Translations
109
+ *.mo
110
+ *.pot
111
+
112
+ # Django stuff:
113
+ *.log
114
+ local_settings.py
115
+ db.sqlite3
116
+ db.sqlite3-journal
117
+
118
+ # Flask stuff:
119
+ instance/
120
+ .webassets-cache
121
+
122
+ # Scrapy stuff:
123
+ .scrapy
124
+
125
+ # Sphinx documentation
126
+ docs/_build/
127
+
128
+ # PyBuilder
129
+ target/
130
+
131
+ # Jupyter Notebook
132
+ .ipynb_checkpoints
133
+
134
+ # IPython
135
+ profile_default/
136
+ ipython_config.py
137
+
138
+ # pyenv
139
+ .python-version
140
+
141
+ # pipenv
142
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
143
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
144
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
145
+ # install all needed dependencies.
146
+ #Pipfile.lock
147
+
148
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
149
+ __pypackages__/
150
+
151
+ # Celery stuff
152
+ celerybeat-schedule
153
+ celerybeat.pid
154
+
155
+ # SageMath parsed files
156
+ *.sage.py
157
+
158
+ # Spyder project settings
159
+ .spyderproject
160
+ .spyproject
161
+
162
+ # Rope project settings
163
+ .ropeproject
164
+
165
+ # mkdocs documentation
166
+ /site
167
+
168
+ # mypy
169
+ .mypy_cache/
170
+ .dmypy.json
171
+ dmypy.json
172
+
173
+ # Pyre type checker
174
+ .pyre/
175
+
176
+ # Additional exclusions
177
+ *.swp
178
+ *.swo
Dockerfile ADDED
@@ -0,0 +1,41 @@
1
+ # Must use a CUDA 11+ base image
2
+ # FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
3
+ FROM pytorch/pytorch:1.13.0-cuda11.6-cudnn8-runtime
4
+
5
+ WORKDIR /
6
+ COPY ./parser /parser
7
+ COPY ./configs /configs
8
+ RUN mkdir /checkpoints
9
+ # Install git
10
+ RUN apt-get update && apt-get install -y --no-install-recommends \
11
+ git \
12
+ build-essential
13
+
14
+ # Install python packages
15
+ RUN pip3 install --upgrade pip
16
+ ADD requirements.txt requirements.txt
17
+ RUN pip3 install -r requirements.txt
18
+
19
+ # Install cv2 dependencies
20
+ RUN apt-get update
21
+ RUN apt-get install ffmpeg libsm6 libxext6 -y
22
+
23
+ # We add the banana boilerplate here
24
+ ADD server.py .
25
+
26
+ # Add your model weight files
27
+ # (in this case we have a python script)
28
+ ADD download.py .
29
+
30
+ # Add your custom app code, init() and inference()
31
+ ADD app.py .
32
+
33
+ ENV PYTHONPATH "${PYTHONPATH}:/parser:/upscaler"
34
+
35
+ # As an alternative to using build args, you can put your token in the next line
36
+ #ENV HF_AUTH_TOKEN={token}
37
+ RUN python3 download.py
38
+
39
+ EXPOSE 8000
40
+
41
+ CMD python3 -u server.py
LICENSE ADDED
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Banana
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
README.md ADDED
@@ -0,0 +1,38 @@
1
+ ---
2
+ title: ClothQuill - AI Clothing Inpainting
3
+ emoji: 👕
4
+ colorFrom: blue
5
+ colorTo: purple
6
+ sdk: gradio
7
+ sdk_version: 5.25.1
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # ClothQuill - AI Clothing Inpainting
13
+
14
+ This Space allows you to inpaint clothing on images using Stable Diffusion. Upload an image, provide a prompt describing the clothing you want to generate, and get multiple inpainted results.
15
+
16
+ ## How to Use
17
+
18
+ 1. Upload an image containing a person
19
+ 2. Select the detected clothing parts you want to modify
+ 3. Enter a prompt describing the clothing you want to generate
20
+ 4. Click "Generate" to get multiple inpainted results
21
+ 5. Download your favorite result
22
+
23
+ ## Examples
24
+
25
+ - Prompt: "A stylish black leather jacket"
26
+ - Prompt: "A formal blue suit with white shirt"
27
+ - Prompt: "A casual red hoodie"
28
+
29
+ ## Technical Details
30
+
31
+ This Space uses:
32
+ - Stable Diffusion for inpainting
33
+ - Segformer (mattmdjaga/segformer_b2_clothes) for clothing segmentation, with U2NET and SCHP parsers also bundled
34
+ - RealESRGAN for upscaling
35
+
36
+ ## License
37
+
38
+ This project is licensed under the MIT License.
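For reference, the core operation behind this Space is a single diffusers inpainting call on a 512x512 image plus a binary clothing mask. A minimal standalone sketch; the image and mask paths and the prompt are placeholders, not files in this repo:

import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

image = Image.open("person.jpg").convert("RGB").resize((512, 512))      # placeholder input
mask = Image.open("clothing_mask.png").convert("L").resize((512, 512))  # white = area to repaint

result = pipe(
    prompt="A stylish black leather jacket",
    image=image,
    mask_image=mask,
    num_inference_steps=50,
    guidance_scale=7.5,
).images[0]
result.save("inpainted.png")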
app.py ADDED
@@ -0,0 +1,429 @@
1
+ import torch
2
+ from torch import autocast
3
+ from diffusers import StableDiffusionInpaintPipeline
4
+ import gradio as gr
5
+ import traceback
6
+ import base64
7
+ from io import BytesIO
8
+ import os
9
+ # import sys
10
+ import PIL
11
+ import json
12
+ import requests
13
+ import logging
14
+ import time
15
+ import warnings
16
+ import numpy as np
17
+ from PIL import Image, ImageDraw
18
+ import cv2
19
+ warnings.filterwarnings("ignore")
20
+
21
+ # sys.path.insert(1, './parser')
22
+
23
+ # from parser.schp_masker import *
24
+ from parser.segformer_parser import SegformerParser
25
+
26
+ # Configure logging
27
+ logging.basicConfig(
28
+ level=logging.INFO,
29
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
30
+ )
31
+ logger = logging.getLogger('clothquill')
32
+
33
+ # Model paths
34
+ SEGFORMER_MODEL = "mattmdjaga/segformer_b2_clothes"
35
+ STABLE_DIFFUSION_MODEL = "stabilityai/stable-diffusion-2-inpainting"
36
+
37
+ # Global variables for models
38
+ parser = None
39
+ model = None
40
+ inpainter = None
41
+ original_image = None # Store the original uploaded image
42
+
43
+ # Color mapping for different clothing parts
44
+ CLOTHING_COLORS = {
45
+ 'Background': (0, 0, 0, 0), # Transparent
46
+ 'Hat': (255, 0, 0, 128), # Red
47
+ 'Hair': (0, 255, 0, 128), # Green
48
+ 'Glove': (0, 0, 255, 128), # Blue
49
+ 'Sunglasses': (255, 255, 0, 128), # Yellow
50
+ 'Upper-clothes': (255, 0, 255, 128), # Magenta
51
+ 'Dress': (0, 255, 255, 128), # Cyan
52
+ 'Coat': (128, 0, 0, 128), # Dark Red
53
+ 'Socks': (0, 128, 0, 128), # Dark Green
54
+ 'Pants': (0, 0, 128, 128), # Dark Blue
55
+ 'Jumpsuits': (128, 128, 0, 128), # Dark Yellow
56
+ 'Scarf': (128, 0, 128, 128), # Dark Magenta
57
+ 'Skirt': (0, 128, 128, 128), # Dark Cyan
58
+ 'Face': (192, 192, 192, 128), # Light Gray
59
+ 'Left-arm': (64, 64, 64, 128), # Dark Gray
60
+ 'Right-arm': (64, 64, 64, 128), # Dark Gray
61
+ 'Left-leg': (32, 32, 32, 128), # Very Dark Gray
62
+ 'Right-leg': (32, 32, 32, 128), # Very Dark Gray
63
+ 'Left-shoe': (16, 16, 16, 128), # Almost Black
64
+ 'Right-shoe': (16, 16, 16, 128), # Almost Black
65
+ }
66
+
67
+ def get_device():
68
+ if torch.cuda.is_available():
69
+ device = "cuda"
70
+ logger.info("Using GPU")
71
+ else:
72
+ device = "cpu"
73
+ logger.info("Using CPU")
74
+ return device
75
+
76
+ def init():
77
+ global parser
78
+ global model
79
+ global inpainter
80
+
81
+ start_time = time.time()
82
+ logger.info("Starting application initialization")
83
+
84
+ try:
85
+ device = get_device()
86
+
87
+ # Check if models directory exists
88
+ if not os.path.exists("models"):
89
+ logger.info("Creating models directory...")
90
+ from download_models import download_models
91
+ download_models()
92
+
93
+ # Initialize Segformer parser
94
+ logger.info("Initializing Segformer parser...")
95
+ parser = SegformerParser(SEGFORMER_MODEL)
96
+
97
+ # Initialize Stable Diffusion model
98
+ logger.info("Initializing Stable Diffusion model...")
99
+ model = StableDiffusionInpaintPipeline.from_pretrained(
100
+ STABLE_DIFFUSION_MODEL,
101
+ safety_checker=None,
102
+ revision="fp16" if device == "cuda" else None,
103
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
104
+ ).to(device)
105
+
106
+ # Initialize inpainter
107
+ logger.info("Initializing inpainter...")
108
+ inpainter = ClothingInpainter(model=model, parser=parser)
109
+
110
+ logger.info(f"Application initialized in {time.time() - start_time:.2f} seconds")
111
+ except Exception as e:
112
+ logger.error(f"Error initializing application: {str(e)}")
113
+ raise e
114
+
115
+ class ClothingInpainter:
116
+ def __init__(self, model_path=None, model=None, parser=None):
117
+ self.device = get_device()
118
+ self.last_mask = None # Store the last generated mask
119
+ self.original_image = None # Store the original image
120
+
121
+ if model_path is None and model is None:
122
+ raise ValueError('No model provided!')
123
+ if model_path is not None:
124
+ self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
125
+ model_path,
126
+ safety_checker=None,
127
+ revision="fp16" if self.device == "cuda" else None,
128
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
129
+ ).to(self.device)
130
+ else:
131
+ self.pipe = model
132
+
133
+ self.parser = parser
134
+
135
+ def make_square(self, im, min_size=256, fill_color=(0, 0, 0, 0)):
136
+ x, y = im.size
137
+ size = max(min_size, x, y)
138
+ new_im = PIL.Image.new('RGBA', (size, size), fill_color)
139
+ new_im.paste(im, (int((size - x) / 2), int((size - y) / 2)))
140
+ return new_im.convert('RGB')
141
+
142
+ def unmake_square(self, init_im, op_im, min_size=256, rs_size=512):
143
+ x, y = init_im.size
144
+ size = max(min_size, x, y)
145
+ factor = rs_size/size
146
+ return op_im.crop((int((size-x) * factor / 2), int((size-y) * factor / 2),\
147
+ int((size+x) * factor / 2), int((size+y) * factor / 2)))
148
+
149
+ def visualize_segmentation(self, image, masks, selected_parts=None):
150
+ """Visualize segmentation with colored overlays for selected parts and gray for unselected."""
151
+ # Always use original image if available
152
+ image_to_use = self.original_image if self.original_image is not None else image
153
+
154
+ # Create a copy of the original image
155
+ original_size = image_to_use.size
156
+ vis_image = image_to_use.copy().convert('RGBA')
157
+
158
+ # Create overlay at 512x512
159
+ overlay = Image.new('RGBA', (512, 512), (0, 0, 0, 0))
160
+ draw = ImageDraw.Draw(overlay)
161
+
162
+ # Draw each mask with its corresponding color
163
+ for part_name, mask in masks.items():
164
+ # Convert part name for color lookup
165
+ color_key = part_name.capitalize()  # e.g. 'upper-clothes' -> 'Upper-clothes', matching CLOTHING_COLORS keys
166
+ is_selected = selected_parts and part_name in selected_parts
167
+
168
+ # If selected, use color (with fallback). If unselected, use faint gray
169
+ if is_selected:
170
+ color = CLOTHING_COLORS.get(color_key, (255, 0, 255, 128)) # Default to magenta if no color found
171
+ else:
172
+ color = (180, 180, 180, 80) # Faint gray for unselected
173
+
174
+ mask_array = np.array(mask)
175
+ coords = np.where(mask_array > 0)
176
+ for y, x in zip(coords[0], coords[1]):
177
+ draw.point((x, y), fill=color)
178
+
179
+ # Resize overlay to match original image size
180
+ overlay = overlay.resize(original_size, Image.Resampling.LANCZOS)
181
+
182
+ # Composite the overlay onto the original image
183
+ vis_image = Image.alpha_composite(vis_image, overlay)
184
+ return vis_image
185
+
186
+ def inpaint(self, prompt, init_image, selected_parts=None, dilation_iterations=2) -> dict:
187
+ image = self.make_square(init_image).resize((512,512))
188
+
189
+ if self.parser is not None:
190
+ masks = self.parser.get_all_masks(image)
191
+ masks = {k: v.resize((512,512)) for k, v in masks.items()}
192
+ else:
193
+ raise ValueError('Image Parser is Missing')
194
+
195
+ logger.info(f'[generated required mask(s) at {time.time()}]')
196
+
197
+ # Create combined mask for selected parts
198
+ if selected_parts:
199
+ combined_mask = Image.new('L', (512, 512), 0)
200
+ for part in selected_parts:
201
+ if part in masks:
202
+ mask_array = np.array(masks[part])
203
+ kernel = np.ones((5,5), np.uint8)
204
+ dilated_mask = cv2.dilate(mask_array, kernel, iterations=dilation_iterations)
205
+ dilated_mask = Image.fromarray(dilated_mask)
206
+ combined_mask = Image.composite(
207
+ Image.new('L', (512, 512), 255),
208
+ combined_mask,
209
+ dilated_mask
210
+ )
211
+ else:
212
+ # If no parts selected, use all clothing parts
213
+ combined_mask = Image.new('L', (512, 512), 0)
214
+ for part, mask in masks.items():
215
+ if part in ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']:
216
+ mask_array = np.array(mask)
217
+ kernel = np.ones((5,5), np.uint8)
218
+ dilated_mask = cv2.dilate(mask_array, kernel, iterations=dilation_iterations)
219
+ dilated_mask = Image.fromarray(dilated_mask)
220
+ combined_mask = Image.composite(
221
+ Image.new('L', (512, 512), 255),
222
+ combined_mask,
223
+ dilated_mask
224
+ )
225
+
226
+ # Run the model
227
+ guidance_scale=7.5
228
+ num_samples = 3
229
+ with autocast("cuda"), torch.inference_mode():
230
+ images = self.pipe(
231
+ num_inference_steps = 50,
232
+ prompt=prompt['pos'],
233
+ image=image,
234
+ mask_image=combined_mask,
235
+ guidance_scale=guidance_scale,
236
+ num_images_per_prompt=num_samples,
237
+ ).images
238
+
239
+ images_output = []
240
+ for img in images:
241
+ ch = PIL.Image.composite(img, image, combined_mask)
242
+ fin_img = self.unmake_square(init_image, ch)
243
+ images_output.append(fin_img)
244
+
245
+ return images_output
246
+
247
+ def process_segmentation(image, dilation_iterations=2):
248
+ try:
249
+ if image is None:
250
+ raise gr.Error("Please upload an image")
251
+
252
+ # Store original image
253
+ inpainter.original_image = image.copy()
254
+
255
+ # Create a processing copy at 512x512
256
+ proc_image = image.resize((512, 512), Image.Resampling.LANCZOS)
257
+
258
+ # Get the main mask
259
+ all_masks = inpainter.parser.get_all_masks(proc_image)
260
+ if not all_masks:
261
+ logger.error("No clothing detected in the image")
262
+ raise gr.Error("No clothing detected in the image. Please try a different image.")
263
+ inpainter.last_mask = all_masks
264
+ # Only show main clothing parts for selection
265
+ main_parts = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
266
+ masks = {k: v for k, v in all_masks.items() if k in main_parts}
267
+ vis_image = inpainter.visualize_segmentation(image, masks, selected_parts=None)
268
+ detected_parts = [k for k in masks.keys()]
269
+ return vis_image, gr.update(choices=detected_parts, value=[])
270
+ except gr.Error as e:
271
+ raise e
272
+ except Exception as e:
273
+ logger.error(f"Error processing segmentation: {str(e)}")
274
+ raise gr.Error("Error processing the image. Please try a different image.")
275
+
276
+ def update_dilation(image, selected_parts, dilation_iterations):
277
+ try:
278
+ if image is None or inpainter.last_mask is None:
279
+ return image
280
+ # Redilate all stored masks
281
+ main_parts = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
282
+ masks = {}
283
+ for part in main_parts:
284
+ if part in inpainter.last_mask:
285
+ mask_array = np.array(inpainter.last_mask[part])
286
+ kernel = np.ones((5,5), np.uint8)
287
+ dilated_mask = cv2.dilate(mask_array, kernel, iterations=dilation_iterations)
288
+ masks[part] = Image.fromarray(dilated_mask)
289
+ # Use original image for visualization
290
+ vis_image = inpainter.visualize_segmentation(inpainter.original_image, masks, selected_parts=selected_parts)
291
+ return vis_image
292
+ except Exception as e:
293
+ logger.error(f"Error updating dilation: {str(e)}")
294
+ return image
295
+
296
+ def process_image(prompt, image, selected_parts, dilation_iterations):
297
+ start_time = time.time()
298
+ logger.info(f"Processing new request - Prompt: {prompt}, Image size: {image.size if image else 'None'}")
299
+
300
+ try:
301
+ if image is None:
302
+ logger.error("No image provided")
303
+ raise gr.Error("Please upload an image")
304
+ if not prompt:
305
+ logger.error("No prompt provided")
306
+ raise gr.Error("Please enter a prompt")
307
+ if not selected_parts:
308
+ logger.error("No parts selected")
309
+ raise gr.Error("Please select at least one clothing part to modify")
310
+
311
+ prompt_dict = {'pos': prompt}
312
+ logger.info("Starting inpainting process")
313
+
314
+ # Generate inpainted images
315
+ # Convert selected_parts to lowercase/dash format
316
+ selected_parts = [p.lower() for p in selected_parts]
317
+ images = inpainter.inpaint(prompt_dict, image, selected_parts, dilation_iterations)
318
+
319
+ if not images:
320
+ logger.error("Inpainting failed to produce results")
321
+ raise gr.Error("Failed to generate images. Please try again.")
322
+
323
+ logger.info(f"Request processed in {time.time() - start_time:.2f} seconds")
324
+ return images
325
+ except Exception as e:
326
+ logger.error(f"Error processing image: {str(e)}")
327
+ raise gr.Error(f"Error processing image: {str(e)}")
328
+
329
+ def update_selected_parts(image, selected_parts, dilation_iterations):
330
+ try:
331
+ if image is None or inpainter.last_mask is None:
332
+ return image
333
+ main_parts = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
334
+ masks = {}
335
+ for part in main_parts:
336
+ if part in inpainter.last_mask:
337
+ mask_array = np.array(inpainter.last_mask[part])
338
+ kernel = np.ones((5,5), np.uint8)
339
+ dilated_mask = cv2.dilate(mask_array, kernel, iterations=dilation_iterations)
340
+ masks[part] = Image.fromarray(dilated_mask)
341
+ # Lowercase the selected_parts for comparison
342
+ selected_parts = [p.lower() for p in selected_parts] if selected_parts else []
343
+ # Use original image for visualization
344
+ vis_image = inpainter.visualize_segmentation(inpainter.original_image, masks, selected_parts=selected_parts)
345
+ return vis_image
346
+ except Exception as e:
347
+ logger.error(f"Error updating selected parts: {str(e)}")
348
+ return image
349
+
350
+ # Initialize the model
351
+ init()
352
+
353
+ # Create Gradio interface
354
+ with gr.Blocks(title="ClothQuill - AI Clothing Inpainting") as demo:
355
+ gr.Markdown("# ClothQuill - AI Clothing Inpainting")
356
+ gr.Markdown("Upload an image to see segmented clothing parts, then select parts to modify and describe your changes")
357
+
358
+ with gr.Row():
359
+ with gr.Column():
360
+ input_image = gr.Image(
361
+ type="pil",
362
+ label="Upload Image",
363
+ scale=1, # This ensures the image maintains its aspect ratio
364
+ height=None # Allow dynamic height based on content
365
+ )
366
+ dilation_slider = gr.Slider(
367
+ minimum=0,
368
+ maximum=5,
369
+ value=2,
370
+ step=1,
371
+ label="Mask Dilation",
372
+ info="Adjust the mask dilation to control the area of modification"
373
+ )
374
+ selected_parts = gr.CheckboxGroup(
375
+ choices=[],
376
+ label="Select parts to modify",
377
+ value=[]
378
+ )
379
+ prompt = gr.Textbox(
380
+ label="Describe the clothing you want to generate",
381
+ placeholder="e.g., A stylish black leather jacket"
382
+ )
383
+ generate_btn = gr.Button("Generate")
384
+
385
+ with gr.Column():
386
+ gallery = gr.Gallery(
387
+ label="Generated Results",
388
+ show_label=False,
389
+ columns=2,
390
+ height=None, # Allow dynamic height
391
+ object_fit="contain" # Maintain aspect ratio
392
+ )
393
+
394
+ # Add event handler for image upload
395
+ input_image.upload(
396
+ fn=process_segmentation,
397
+ inputs=[input_image, dilation_slider],
398
+ outputs=[input_image, selected_parts]
399
+ )
400
+
401
+ # Add event handler for dilation changes
402
+ dilation_slider.change(
403
+ fn=update_dilation,
404
+ inputs=[input_image, selected_parts,dilation_slider],
405
+ outputs=input_image
406
+ )
407
+
408
+ # Add event handler for generation
409
+ generate_btn.click(
410
+ fn=process_image,
411
+ inputs=[prompt, input_image, selected_parts, dilation_slider],
412
+ outputs=gallery
413
+ )
414
+
415
+ # Add event handler for part selection changes
416
+ selected_parts.change(
417
+ fn=update_selected_parts,
418
+ inputs=[input_image, selected_parts, dilation_slider],
419
+ outputs=input_image
420
+ )
421
+
422
+ if __name__ == "__main__":
423
+ demo.launch(share=True)
424
+
425
+
426
+
427
+
428
+
429
+
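The Gradio handlers above ultimately just drive the global ClothingInpainter, which can also be used directly from Python once app.py is importable. A sketch; note that importing app runs init() and loads both models, and the input path is a placeholder:

from PIL import Image
import app  # module-level init() loads the Segformer parser and the SD2 inpainting pipeline

img = Image.open("person.jpg").convert("RGB")
masks = app.inpainter.parser.get_all_masks(img.resize((512, 512)))
print("detected parts:", list(masks.keys()))

results = app.inpainter.inpaint(
    {"pos": "A casual red hoodie"},    # inpaint() expects a dict with a 'pos' prompt
    img,
    selected_parts=["upper-clothes"],  # keys as returned by the parser
    dilation_iterations=2,
)
results[0].save("result.png")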
colab_demo.py ADDED
@@ -0,0 +1,201 @@
1
+ import torch
2
+ from torch import autocast
3
+ from diffusers import StableDiffusionInpaintPipeline
4
+ import gradio as gr
5
+ import traceback
6
+ import base64
7
+ from io import BytesIO
8
+ import os
9
+ import PIL
10
+ import json
11
+ import requests
12
+ import logging
13
+ import time
14
+ import warnings
15
+ warnings.filterwarnings("ignore")
16
+
17
+ # Configure logging
18
+ logging.basicConfig(
19
+ level=logging.INFO,
20
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
21
+ )
22
+ logger = logging.getLogger('looks.studio')
23
+
24
+ # Model paths
25
+ SEGFORMER_MODEL = "mattmdjaga/segformer_b2_clothes"
26
+ STABLE_DIFFUSION_MODEL = "stabilityai/stable-diffusion-2-inpainting"
27
+
28
+ # Global variables for models
29
+ parser = None
30
+ model = None
31
+ inpainter = None
32
+
33
+ def get_device():
34
+ if torch.cuda.is_available():
35
+ device = "cuda"
36
+ logger.info("Using GPU")
37
+ else:
38
+ device = "cpu"
39
+ logger.info("Using CPU")
40
+ return device
41
+
42
+ def init():
43
+ global parser
44
+ global model
45
+ global inpainter
46
+
47
+ start_time = time.time()
48
+ logger.info("Starting application initialization")
49
+
50
+ try:
51
+ device = get_device()
52
+
53
+ # Initialize Segformer parser
54
+ logger.info("Initializing Segformer parser...")
55
+ from parser.segformer_parser import SegformerParser
56
+ parser = SegformerParser(SEGFORMER_MODEL)
57
+
58
+ # Initialize Stable Diffusion model
59
+ logger.info("Initializing Stable Diffusion model...")
60
+ model = StableDiffusionInpaintPipeline.from_pretrained(
61
+ STABLE_DIFFUSION_MODEL,
62
+ safety_checker=None,
63
+ revision="fp16" if device == "cuda" else None,
64
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32
65
+ ).to(device)
66
+
67
+ # Initialize inpainter
68
+ logger.info("Initializing inpainter...")
69
+ inpainter = ClothingInpainter(model=model, parser=parser)
70
+
71
+ logger.info(f"Application initialized in {time.time() - start_time:.2f} seconds")
72
+ except Exception as e:
73
+ logger.error(f"Error initializing application: {str(e)}")
74
+ raise e
75
+
76
+ class ClothingInpainter:
77
+ def __init__(self, model_path=None, model=None, parser=None):
78
+ self.device = get_device()
79
+
80
+ if model_path is None and model is None:
81
+ raise ValueError('No model provided!')
82
+ if model_path is not None:
83
+ self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
84
+ model_path,
85
+ safety_checker=None,
86
+ revision="fp16" if self.device == "cuda" else None,
87
+ torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
88
+ ).to(self.device)
89
+ else:
90
+ self.pipe = model
91
+
92
+ self.parser = parser
93
+
94
+ def make_square(self, im, min_size=256, fill_color=(0, 0, 0, 0)):
95
+ x, y = im.size
96
+ size = max(min_size, x, y)
97
+ new_im = PIL.Image.new('RGBA', (size, size), fill_color)
98
+ new_im.paste(im, (int((size - x) / 2), int((size - y) / 2)))
99
+ return new_im.convert('RGB')
100
+
101
+ def unmake_square(self, init_im, op_im, min_size=256, rs_size=512):
102
+ x, y = init_im.size
103
+ size = max(min_size, x, y)
104
+ factor = rs_size/size
105
+ return op_im.crop((int((size-x) * factor / 2), int((size-y) * factor / 2),\
106
+ int((size+x) * factor / 2), int((size+y) * factor / 2)))
107
+
108
+ def inpaint(self, prompt, init_image, parser=None) -> dict:
109
+ image = self.make_square(init_image).resize((512,512))
110
+
111
+ if self.parser is not None:
112
+ mask = self.parser.get_image_mask(image)
113
+ mask = mask.resize((512,512))
114
+ elif parser is not None:
115
+ mask = parser.get_image_mask(image)
116
+ mask = mask.resize((512,512))
117
+ else:
118
+ raise ValueError('Image Parser is Missing')
119
+ logger.info(f'[generated required mask(s) at {time.time()}]')
120
+
121
+ # Run the model
122
+ guidance_scale=7.5
123
+ num_samples = 3
124
+ with autocast("cuda"), torch.inference_mode():
125
+ images = self.pipe(
126
+ num_inference_steps = 50,
127
+ prompt=prompt['pos'],
128
+ image=image,
129
+ mask_image=mask,
130
+ guidance_scale=guidance_scale,
131
+ num_images_per_prompt=num_samples,
132
+ ).images
133
+
134
+ images_output = []
135
+ for img in images:
136
+ ch = PIL.Image.composite(img,image, mask.convert('L'))
137
+ fin_img = self.unmake_square(init_image, ch)
138
+ images_output.append(fin_img)
139
+
140
+ return images_output
141
+
142
+ def process_image(prompt, image):
143
+ start_time = time.time()
144
+ logger.info(f"Processing new request - Prompt: {prompt}, Image size: {image.size if image else 'None'}")
145
+
146
+ try:
147
+ if image is None:
148
+ logger.error("No image provided")
149
+ raise gr.Error("Please upload an image")
150
+ if not prompt:
151
+ logger.error("No prompt provided")
152
+ raise gr.Error("Please enter a prompt")
153
+
154
+ prompt_dict = {'pos': prompt}
155
+ logger.info("Starting inpainting process")
156
+ images = inpainter.inpaint(prompt_dict, image)
157
+
158
+ if not images:
159
+ logger.error("Inpainting failed to produce results")
160
+ raise gr.Error("Failed to generate images. Please try again.")
161
+
162
+ logger.info(f"Request processed in {time.time() - start_time:.2f} seconds")
163
+ return images
164
+ except Exception as e:
165
+ logger.error(f"Error processing image: {str(e)}")
166
+ raise gr.Error(f"Error processing image: {str(e)}")
167
+
168
+ # Initialize the model
169
+ init()
170
+
171
+ # Create Gradio interface
172
+ with gr.Blocks(title="Looks.Studio - AI Clothing Inpainting") as demo:
173
+ gr.Markdown("# Looks.Studio - AI Clothing Inpainting")
174
+ gr.Markdown("Upload an image and describe the clothing you want to generate")
175
+
176
+ with gr.Row():
177
+ with gr.Column():
178
+ input_image = gr.Image(
179
+ type="pil",
180
+ label="Upload Image",
181
+ height=512
182
+ )
183
+ prompt = gr.Textbox(label="Describe the clothing you want to generate")
184
+ generate_btn = gr.Button("Generate")
185
+
186
+ with gr.Column():
187
+ gallery = gr.Gallery(
188
+ label="Generated Images",
189
+ show_label=False,
190
+ columns=2,
191
+ height=512
192
+ )
193
+
194
+ generate_btn.click(
195
+ fn=process_image,
196
+ inputs=[prompt, input_image],
197
+ outputs=gallery
198
+ )
199
+
200
+ if __name__ == "__main__":
201
+ demo.launch(share=True)
configs/configs.json ADDED
@@ -0,0 +1,20 @@
1
+ {
2
+ "models": {
3
+ "schp": {
4
+ "download_url": "https://storage.googleapis.com/platform-ai-prod/looks_studio_data/schp_checkpoint.pth",
5
+ "path": "checkpoints/schp.pth"
6
+ },
7
+ "u2net": {
8
+ "download_url": "https://storage.googleapis.com/platform-ai-prod/looks_studio_data/cloth_segm_u2net_latest.pth",
9
+ "path": "checkpoints/cloth_segm_u2net_latest.pth"
10
+ },
11
+ "realesrgan": {
12
+ "download_url": "https://storage.googleapis.com/platform-ai-prod/looks_studio_data/RealESRGAN_x4plus.pth",
13
+ "path": "checkpoints/realesrgan_x4plus.pth"
14
+ },
15
+ "diffuser": {
16
+ "download_url": "https://storage.googleapis.com/platform-ai-prod/looks_studio_data/diffusers/stable_diffusion_2_checkpoint.tar",
17
+ "path": "checkpoints/stable_diffusion_2_inpainting"
18
+ }
19
+ }
20
+ }
download.py ADDED
@@ -0,0 +1,34 @@
1
+ # In this file, we define download_models
2
+ # It runs during container build time to get model weights built into the container
3
+
4
+ import os
5
+ import wget
6
+ import json
7
+ import tarfile
8
+ import tempfile
9
+
10
+ def download_models(config):
11
+ # Download parser checkpoint
12
+ # wget.download(config['schp']['download_url'],
13
+ # os.path.join(os.path.dirname(__file__), config['schp']['path']))
14
+ wget.download(config['u2net']['download_url'],
15
+ os.path.join(os.path.dirname(__file__), config['u2net']['path']))
16
+
17
+ # Download Super resolution model
18
+ wget.download(config['realesrgan']['download_url'],
19
+ os.path.join(os.path.dirname(__file__), config['realesrgan']['path']))
20
+
21
+ # Download diffuser model checkpoint
22
+ _, local_file_name = tempfile.mkstemp()
23
+ local_file_name += '.tar'
24
+ wget.download(config['diffuser']['download_url'], local_file_name)
25
+ tar_file = tarfile.open(local_file_name)
26
+ tar_file.extractall(os.path.join(os.path.dirname(__file__),'checkpoints/'))
27
+
28
+ if __name__ == "__main__":
29
+ config_file = "configs/configs.json"
30
+ config_file = os.path.join(os.path.dirname(__file__), config_file)
31
+
32
+ with open(config_file) as fin:
33
+ config = json.load(fin)
34
+ download_models(config['models'])
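Note that wget.download does not create missing directories; inside the image the Dockerfile runs mkdir /checkpoints, but when running this script locally the checkpoints directory must exist first. A small local-use sketch to run before download_models:

import os

base = os.path.dirname(os.path.abspath(__file__))
os.makedirs(os.path.join(base, "checkpoints"), exist_ok=True)  # target directory for the .pth downloads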
download_models.py ADDED
@@ -0,0 +1,22 @@
1
+ import os
2
+ import time
3
+ import logging
4
+
5
+ logger = logging.getLogger(__name__)
6
+
7
+ def download_models():
8
+ """Download required models for the application"""
9
+ start_time = time.time()
10
+ logger.info("Starting model download process")
11
+
12
+ try:
13
+ # Create models directory if it doesn't exist
14
+ os.makedirs("models", exist_ok=True)
15
+
16
+ logger.info(f"Model setup completed in {time.time() - start_time:.2f} seconds")
17
+ except Exception as e:
18
+ logger.error(f"Error in model setup: {str(e)}")
19
+ raise
20
+
21
+ if __name__ == "__main__":
22
+ download_models()
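As committed, download_models() only creates the models directory; the Hub weights are fetched lazily by from_pretrained on first start-up. If a warm cache is wanted, a hedged sketch of pre-fetching the two repos used in app.py with huggingface_hub (already in requirements.txt):

from huggingface_hub import snapshot_download

for repo_id in (
    "mattmdjaga/segformer_b2_clothes",            # SEGFORMER_MODEL in app.py
    "stabilityai/stable-diffusion-2-inpainting",  # STABLE_DIFFUSION_MODEL in app.py
):
    snapshot_download(repo_id)  # populates the local Hugging Face cache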
parser/__init__.py ADDED
File without changes
parser/schp_masker.py ADDED
@@ -0,0 +1,200 @@
1
+ import cv2
2
+ import sys
3
+ import numpy as np
4
+
5
+ import torch
6
+ from torch.utils import data
7
+ from torch.utils.data import DataLoader
8
+ import torchvision.transforms as transforms
9
+
10
+ from PIL import Image
11
+ from collections import OrderedDict
12
+
13
+ sys.path.insert(1, './schp')
14
+ from utils.transforms import get_affine_transform
15
+ import networks
16
+ from utils.transforms import transform_logits
17
+
18
+ class PILImageDataset(data.Dataset):
19
+ def __init__(self, img_lst=[], input_size=[512, 512], transform=None):
20
+ self.img_lst = img_lst
21
+ self.input_size = input_size
22
+ self.transform = transform
23
+ self.aspect_ratio = input_size[1] * 1.0 / input_size[0]
24
+ self.input_size = np.asarray(input_size)
25
+
26
+ def __len__(self):
27
+ return len(self.img_lst)
28
+
29
+ def _box2cs(self, box):
30
+ x, y, w, h = box[:4]
31
+ return self._xywh2cs(x, y, w, h)
32
+
33
+ def _xywh2cs(self, x, y, w, h):
34
+ center = np.zeros((2), dtype=np.float32)
35
+ center[0] = x + w * 0.5
36
+ center[1] = y + h * 0.5
37
+ if w > self.aspect_ratio * h:
38
+ h = w * 1.0 / self.aspect_ratio
39
+ elif w < self.aspect_ratio * h:
40
+ w = h * self.aspect_ratio
41
+ scale = np.array([w, h], dtype=np.float32)
42
+ return center, scale
43
+
44
+ def __getitem__(self, index):
45
+ img = np.array(self.img_lst[index])[:,:,::-1]
46
+ h, w, _ = img.shape
47
+
48
+ # Get person center and scale
49
+ person_center, s = self._box2cs([0, 0, w - 1, h - 1])
50
+ r = 0
51
+ trans = get_affine_transform(person_center, s, r, self.input_size)
52
+ input = cv2.warpAffine(
53
+ img,
54
+ trans,
55
+ (int(self.input_size[1]), int(self.input_size[0])),
56
+ flags=cv2.INTER_LINEAR,
57
+ borderMode=cv2.BORDER_CONSTANT,
58
+ borderValue=(0, 0, 0))
59
+
60
+ input = self.transform(input)
61
+ meta = {
62
+ 'center': person_center,
63
+ 'height': h,
64
+ 'width': w,
65
+ 'scale': s,
66
+ 'rotation': r
67
+ }
68
+
69
+ return input, meta
70
+
71
+ PALLETE_DICT = {
72
+ 'Background': [],
73
+ 'Face': [],
74
+ 'Upper-clothes':[],
75
+ 'Dress':[],
76
+ 'Coat':[],
77
+ 'Soaks':[],
78
+ 'Pants':[],
79
+ 'Jumpsuits':[],
80
+ 'Scarf':[],
81
+ 'Skirt':[],
82
+ 'Arm':[],
83
+ 'Leg':[],
84
+ 'Shoe':[]
85
+ }
86
+
87
+ val_list = [[0],[1,4,13],[5],[6],[7],[8],[9],[10],[11],[12],[14,15],[16,17],[18,19]]
88
+ for c,j in enumerate(PALLETE_DICT.keys()):
89
+ val = val_list[c]
90
+ pallete = []
91
+ for i in range(60):
92
+ if len(val) == 1:
93
+ if (i >= (val[0]*3)) & (i < ((val[0]+1)*3)):
94
+ pallete.append(255)
95
+ else:
96
+ pallete.append(0)
97
+ if len(val) == 2:
98
+ if (i >= (val[0]*3)) & (i < ((val[0]+1)*3)) or (i >= (val[1]*3)) & (i < ((val[1]+1)*3)):
99
+ pallete.append(255)
100
+ else:
101
+ pallete.append(0)
102
+ if len(val) == 3:
103
+ if (i >= (val[0]*3)) & (i < ((val[0]+1)*3)) or (i >= (val[1]*3)) & (i < ((val[1]+1)*3)) or (i >= (val[2]*3)) & (i < ((val[2]+1)*3)):
104
+ pallete.append(255)
105
+ else:
106
+ pallete.append(0)
107
+
108
+ PALLETE_DICT[j] = pallete
109
+
110
+
111
+ DATASET_SETTINGS = {
112
+ 'lip': {
113
+ 'input_size': [473, 473],
114
+ 'num_classes': 20,
115
+ 'label': ['Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes', 'Dress', 'Coat',
116
+ 'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm',
117
+ 'Left-leg', 'Right-leg', 'Left-shoe', 'Right-shoe']
118
+ },
119
+ 'atr': {
120
+ 'input_size': [512, 512],
121
+ 'num_classes': 18,
122
+ 'label': ['Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt', 'Pants', 'Dress', 'Belt',
123
+ 'Left-shoe', 'Right-shoe', 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf']
124
+ },
125
+ 'pascal': {
126
+ 'input_size': [512, 512],
127
+ 'num_classes': 7,
128
+ 'label': ['Background', 'Head', 'Torso', 'Upper Arms', 'Lower Arms', 'Upper Legs', 'Lower Legs'],
129
+ }
130
+ }
131
+
132
+
133
+
134
+ class SCHPParser:
135
+ def __init__(self, checkpoint_path, dataset_settings):
136
+ self.cp_path = checkpoint_path
137
+ self.ops = []
138
+ self.num_classes = dataset_settings['lip']['num_classes']
139
+ self.input_size = dataset_settings['lip']['input_size']
140
+ self.label = dataset_settings['lip']['label']
141
+ self.pallete_dict = PALLETE_DICT
142
+
143
+
144
+ self.img_transforms = transforms.Compose([
145
+ transforms.ToTensor(),
146
+ transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
147
+ ])
148
+
149
+ self.model = self.load_model()
150
+
151
+
152
+ def load_model(self):
153
+ model = networks.init_model('resnet101', num_classes=self.num_classes, pretrained=None)
154
+ state_dict = torch.load(self.cp_path)['state_dict']
155
+ new_state_dict = OrderedDict()
156
+ for k, v in state_dict.items():
157
+ name = k[7:] # remove `module.`
158
+ new_state_dict[name] = v
159
+ model.load_state_dict(new_state_dict)
160
+ model.cuda()
161
+ model.eval()
162
+ return model
163
+
164
+ def create_dataloader(self, img_lst):
165
+ dataset = PILImageDataset(img_lst, input_size=self.input_size, transform=self.img_transforms)
166
+ # dataset = SimpleFolderDataset('inputs',input_size, transform)
167
+ dataloader = DataLoader(dataset)
168
+ return dataloader
169
+
170
+ def get_image_masks(self, img_lst):
171
+ print("Evaluating total class number {} with {}".format(self.num_classes, self.label))
172
+ dataloader = self.create_dataloader(img_lst)
173
+ with torch.no_grad():
174
+ for batch in dataloader:
175
+ op_dict = {}
176
+ image, meta = batch
177
+ c = meta['center'].numpy()[0]
178
+ s = meta['scale'].numpy()[0]
179
+ w = meta['width'].numpy()[0]
180
+ h = meta['height'].numpy()[0]
181
+
182
+ output = self.model(image.cuda())
183
+ upsample = torch.nn.Upsample(size=self.input_size, mode='bilinear', align_corners=True)
184
+ upsample_output = upsample(output[0][-1][0].unsqueeze(0))
185
+ upsample_output = upsample_output.squeeze()
186
+ upsample_output = upsample_output.permute(1, 2, 0) # CHW -> HWC
187
+
188
+ logits_result = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h, input_size=self.input_size)
189
+ parsing_result = np.argmax(logits_result, axis=2)
190
+ output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
191
+ for loc, key in enumerate(self.pallete_dict.keys()):
192
+ output_img.putpalette(self.pallete_dict[key])
193
+ op_dict.update({
194
+ key: output_img.convert('L')
195
+ })
196
+ self.ops.append(op_dict)
197
+ return self.ops
198
+
199
+
200
+
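SCHPParser is not wired into app.py; it expects the vendored SCHP sources (./schp, not part of this commit), the schp checkpoint from configs/configs.json, and a CUDA device. A usage sketch under those assumptions:

from PIL import Image
from parser.schp_masker import SCHPParser, DATASET_SETTINGS

parser = SCHPParser("checkpoints/schp.pth", DATASET_SETTINGS)  # the LIP settings are used internally
masks_per_image = parser.get_image_masks([Image.open("person.jpg").convert("RGB")])
upper_mask = masks_per_image[0]["Upper-clothes"]               # one grayscale PIL mask per label group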
parser/segformer_parser.py ADDED
@@ -0,0 +1,185 @@
1
+ import torch
2
+ import numpy as np
3
+ from PIL import Image
4
+ from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
5
+ import torch.nn.functional as F
6
+ import logging
7
+ import time
8
+ from typing import Tuple, Optional
9
+
10
+ logger = logging.getLogger('looks.studio.segformer')
11
+
12
+ class SegformerParser:
13
+ def __init__(self, model_path="mattmdjaga/segformer_b2_clothes"):
14
+ self.start_time = time.time()
15
+ logger.info(f"Initializing SegformerParser with model: {model_path}")
16
+
17
+ try:
18
+ self.processor = SegformerImageProcessor.from_pretrained(model_path)
19
+ self.model = AutoModelForSemanticSegmentation.from_pretrained(model_path)
20
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ logger.info(f"Using device: {self.device}")
22
+ self.model.to(self.device)
23
+
24
+ # Define clothing-related labels
25
+ self.clothing_labels = {
26
+ 4: "upper-clothes",
27
+ 5: "skirt",
28
+ 6: "pants",
29
+ 7: "dress",
30
+ 8: "belt",
31
+ 9: "left-shoe",
32
+ 10: "right-shoe",
33
+ 14: "left-arm",
34
+ 15: "right-arm",
35
+ 17: "scarf"
36
+ }
37
+
38
+ logger.info(f"SegformerParser initialized in {time.time() - self.start_time:.2f} seconds")
39
+ except Exception as e:
40
+ logger.error(f"Failed to initialize SegformerParser: {str(e)}")
41
+ raise
42
+
43
+ def _resize_image(self, image: Image.Image, max_size: int = 1024) -> Tuple[Image.Image, float]:
44
+ """Resize image while maintaining aspect ratio if it exceeds max_size"""
45
+ width, height = image.size
46
+ scale = 1.0
47
+
48
+ if width > max_size or height > max_size:
49
+ scale = max_size / max(width, height)
50
+ new_width = int(width * scale)
51
+ new_height = int(height * scale)
52
+ logger.info(f"Resizing image from {width}x{height} to {new_width}x{new_height}")
53
+ image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)
54
+
55
+ return image, scale
56
+
57
+ def _validate_image(self, image: Image.Image) -> bool:
58
+ """Validate input image"""
59
+ if not isinstance(image, Image.Image):
60
+ logger.error("Input is not a PIL Image")
61
+ return False
62
+
63
+ if image.mode not in ['RGB', 'RGBA']:
64
+ logger.error(f"Unsupported image mode: {image.mode}")
65
+ return False
66
+
67
+ width, height = image.size
68
+ if width < 64 or height < 64:
69
+ logger.error(f"Image too small: {width}x{height}")
70
+ return False
71
+
72
+ if width > 4096 or height > 4096:
73
+ logger.error(f"Image too large: {width}x{height}")
74
+ return False
75
+
76
+ return True
77
+
78
+ def get_image_mask(self, image: Image.Image) -> Optional[Image.Image]:
79
+ """Generate segmentation mask for clothing"""
80
+ start_time = time.time()
81
+ logger.info(f"Starting segmentation for image size: {image.size}")
82
+
83
+ try:
84
+ # Validate input image
85
+ if not self._validate_image(image):
86
+ return None
87
+
88
+ # Convert RGBA to RGB if necessary
89
+ if image.mode == 'RGBA':
90
+ logger.info("Converting RGBA to RGB")
91
+ image = image.convert('RGB')
92
+
93
+ # Resize image if too large
94
+ image, scale = self._resize_image(image)
95
+
96
+ # Process the image
97
+ logger.info("Processing image with Segformer")
98
+ inputs = self.processor(images=image, return_tensors="pt").to(self.device)
99
+
100
+ # Get predictions
101
+ with torch.no_grad():
102
+ outputs = self.model(**inputs)
103
+ logits = outputs.logits.cpu()
104
+
105
+ # Upsample logits to original image size
106
+ upsampled_logits = F.interpolate(
107
+ logits,
108
+ size=image.size[::-1],
109
+ mode="bilinear",
110
+ align_corners=False,
111
+ )
112
+
113
+ # Get the segmentation mask
114
+ pred_seg = upsampled_logits.argmax(dim=1)[0]
115
+
116
+ # Create a binary mask for clothing
117
+ mask = torch.zeros_like(pred_seg)
118
+ for label_id in self.clothing_labels.keys():
119
+ mask[pred_seg == label_id] = 255
120
+
121
+ # Convert to PIL Image
122
+ mask = Image.fromarray(mask.numpy().astype(np.uint8))
123
+
124
+ # Resize mask back to original size if needed
125
+ if scale != 1.0:
126
+ original_size = (int(image.size[0] / scale), int(image.size[1] / scale))
127
+ logger.info(f"Resizing mask back to original size: {original_size}")
128
+ mask = mask.resize(original_size, Image.Resampling.NEAREST)
129
+
130
+ logger.info(f"Segmentation completed in {time.time() - start_time:.2f} seconds")
131
+ return mask
132
+
133
+ except Exception as e:
134
+ logger.error(f"Error during segmentation: {str(e)}")
135
+ return None
136
+
137
+ def get_all_masks(self, image: Image.Image) -> dict:
138
+ """Return a dict of binary masks for each clothing part label."""
139
+ start_time = time.time()
140
+ logger.info(f"Starting per-part segmentation for image size: {image.size}")
141
+ masks = {}
142
+ try:
143
+ # Validate input image
144
+ if not self._validate_image(image):
145
+ return masks
146
+
147
+ # Convert RGBA to RGB if necessary
148
+ if image.mode == 'RGBA':
149
+ logger.info("Converting RGBA to RGB")
150
+ image = image.convert('RGB')
151
+
152
+ # Resize image if too large
153
+ image, scale = self._resize_image(image)
154
+
155
+ # Process the image
156
+ logger.info("Processing image with Segformer for all masks")
157
+ inputs = self.processor(images=image, return_tensors="pt").to(self.device)
158
+
159
+ # Get predictions
160
+ with torch.no_grad():
161
+ outputs = self.model(**inputs)
162
+ logits = outputs.logits.cpu()
163
+ upsampled_logits = F.interpolate(
164
+ logits,
165
+ size=image.size[::-1],
166
+ mode="bilinear",
167
+ align_corners=False,
168
+ )
169
+ pred_seg = upsampled_logits.argmax(dim=1)[0]
170
+
171
+ # For each clothing label, create a binary mask
172
+ for label_id, part_name in self.clothing_labels.items():
173
+ mask = (pred_seg == label_id).numpy().astype(np.uint8) * 255
174
+ mask_img = Image.fromarray(mask)
175
+ # Resize mask back to original size if needed
176
+ if scale != 1.0:
177
+ original_size = (int(image.size[0] / scale), int(image.size[1] / scale))
178
+ mask_img = mask_img.resize(original_size, Image.Resampling.NEAREST)
179
+ masks[part_name] = mask_img
180
+
181
+ logger.info(f"Per-part segmentation completed in {time.time() - start_time:.2f} seconds")
182
+ return masks
183
+ except Exception as e:
184
+ logger.error(f"Error during per-part segmentation: {str(e)}")
185
+ return masks
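A usage sketch for the Segformer parser that app.py relies on; the input path is a placeholder:

from PIL import Image
from parser.segformer_parser import SegformerParser

parser = SegformerParser()             # defaults to mattmdjaga/segformer_b2_clothes
img = Image.open("person.jpg").convert("RGB")

combined = parser.get_image_mask(img)  # single binary mask covering all clothing labels
per_part = parser.get_all_masks(img)   # dict: part name -> binary PIL mask
print(sorted(per_part.keys()))         # 'belt', 'dress', ..., 'upper-clothes'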
parser/u2net_parser.py ADDED
@@ -0,0 +1,55 @@
1
+ import os
2
+ # from tqdm import tqdm
3
+ from PIL import Image
4
+ import numpy as np
5
+ import sys
6
+
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torchvision.transforms as transforms
10
+
11
+ from .u2net_cloth_seg.data.base_dataset import Normalize_image
12
+ from .u2net_cloth_seg.utils.saving_utils import load_checkpoint_mgpu
13
+
14
+ from .u2net_cloth_seg.networks import U2NET
15
+
16
+ class U2NETParser:
17
+ def __init__(self, checkpoint_path):
18
+ self.cp_path = checkpoint_path
19
+ self.img_transforms = transforms.Compose([
20
+ transforms.ToTensor(),
21
+ Normalize_image(0.5, 0.5)
22
+ ])
23
+ self.model = self.load_model()
24
+
25
+
26
+ def load_model(self):
27
+ model = U2NET(in_ch=3, out_ch=4)
28
+ model = load_checkpoint_mgpu(model, self.cp_path)
29
+ model = model.to("cuda")
30
+ model = model.eval()
31
+ return model
32
+
33
+ def get_image_mask(self, img):
34
+ # print("Evaluating total class number {} with {}".format(self.num_classes, self.label))
35
+ img_size = img.size
36
+ img = img.resize((768, 768), Image.BICUBIC)
37
+ image_tensor = self.img_transforms(img)
38
+ image_tensor = torch.unsqueeze(image_tensor, 0)
39
+
40
+ with torch.no_grad():
41
+ output_tensor = self.model(image_tensor.to("cuda"))
42
+
43
+ output_tensor = F.log_softmax(output_tensor[0], dim=1)
44
+ output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
45
+ output_tensor = torch.squeeze(output_tensor, dim=0)
46
+ output_tensor = torch.squeeze(output_tensor, dim=0)
47
+ output_arr = output_tensor.cpu().numpy()
48
+
49
+ output_arr[output_arr != 1] = 0
50
+ output_arr[output_arr == 1] = 255
51
+
52
+ output_img = Image.fromarray(output_arr.astype('uint8'), mode='L')
53
+ output_img = output_img.resize(img_size, Image.BICUBIC)
54
+
55
+ return output_img
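U2NETParser depends on the vendored u2net_cloth_seg package (referenced by the relative imports above but not part of this commit), the u2net checkpoint from configs/configs.json, and a CUDA device, since the model is moved to "cuda". A usage sketch under those assumptions:

from PIL import Image
from parser.u2net_parser import U2NETParser

parser = U2NETParser("checkpoints/cloth_segm_u2net_latest.pth")
mask = parser.get_image_mask(Image.open("person.jpg").convert("RGB"))
mask.save("cloth_mask.png")  # white wherever the model predicted class 1 (see the == 1 filter above)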
requirements.txt ADDED
@@ -0,0 +1,21 @@
1
+ sanic>=25.3.0
2
+ # git+https://github.com/huggingface/diffusers.git#egg=diffusers  (duplicate of the pinned diffusers release below)
3
+ transformers>=4.30.0
4
+ scipy>=1.11.0
5
+ opencv-python>=4.8.0
6
+ wget
7
+ # ninja
8
+ accelerate>=0.24.0
9
+ basicsr>=1.4.2
10
+ ftfy>=6.1.1
11
+ # bitsandbytes
12
+ gradio>=3.50.0
13
+ # natsort
14
+ # https://github.com/metrolobo/xformers_wheels/releases/download/1d31a3ac_various_6/xformers-0.0.14.dev0-cp37-cp37m-linux_x86_64.whl
15
+ torch>=2.0.0
16
+ diffusers>=0.19.0
17
+ Pillow>=9.0.0
18
+ requests>=2.28.0
19
+ numpy>=1.24.0
20
+ huggingface_hub>=0.16.0
21
+ matplotlib>=3.7.0 # For visualization
server.py ADDED
@@ -0,0 +1,42 @@
1
+ # Do not edit if deploying to Banana Serverless
2
+ # This file is boilerplate for the http server, and follows a strict interface.
3
+
4
+ # Instead, edit the init() and inference() functions in app.py
5
+
6
+ from sanic import Sanic, response
7
+ import subprocess
8
+ import app as user_src
9
+
10
+ # We do the model load-to-GPU step on server startup
11
+ # so the model object is available globally for reuse
12
+ user_src.init()
13
+
14
+ # Create the http server app
15
+ server = Sanic("my_app")
16
+
17
+ # Healthchecks verify that the environment is correct on Banana Serverless
18
+ @server.route('/healthcheck', methods=["GET"])
19
+ def healthcheck(request):
20
+ # dependency free way to check if GPU is visible
21
+ gpu = False
22
+ out = subprocess.run("nvidia-smi", shell=True)
23
+ if out.returncode == 0: # success state on shell command
24
+ gpu = True
25
+
26
+ return response.json({"state": "healthy", "gpu": gpu})
27
+
28
+ # Inference POST handler at '/' is called for every http call from Banana
29
+ @server.route('/', methods=["POST"])
30
+ def inference(request):
31
+ try:
32
+ model_inputs = json.loads(request.json)
33
+ except:
34
+ model_inputs = request.json
35
+
36
+ output = user_src.inference(model_inputs)
37
+
38
+ return response.json(output)
39
+
40
+
41
+ if __name__ == '__main__':
42
+ server.run(host='0.0.0.0', port=8000, workers=1)
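A client-side sketch for this Banana-style server. It assumes the server is running locally on port 8000 and that app.py exposes an inference() function; the Gradio-based app.py in this commit does not, so this server is effectively legacy boilerplate:

import requests

print(requests.get("http://localhost:8000/healthcheck").json())  # {"state": "healthy", "gpu": ...}

payload = {"prompt": "A casual red hoodie"}                       # hypothetical input schema
print(requests.post("http://localhost:8000/", json=payload).json())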
upscaler/__init__.py ADDED
File without changes
upscaler/realesrgan_upscaler.py ADDED
@@ -0,0 +1,30 @@
1
+ from basicsr.archs.rrdbnet_arch import RRDBNet
2
+ from .real_esrgan.realesrgan import RealESRGANer
3
+
4
+ import cv2
5
+ import numpy as np
6
+ from PIL import Image
7
+
8
+
9
+ class RealESRGAN:
10
+ def __init__(self, checkpoint_path):
11
+
12
+ self.netscale = 4
13
+
14
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
15
+
16
+ self.upsampler = RealESRGANer(
17
+ scale=self.netscale,
18
+ model_path=checkpoint_path,
19
+ dni_weight=None,
20
+ model=model,
21
+ tile=0,
22
+ tile_pad=10,
23
+ pre_pad=0,
24
+ half=True)
25
+
26
+ def upscale(self, pil_image, scale_factor=3):
27
+ cv2_img = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
28
+ op, _ = self.upsampler.enhance(cv2_img, outscale=scale_factor)
29
+ pil_image_fin = Image.fromarray(cv2.cvtColor(op, cv2.COLOR_BGR2RGB))
30
+ return pil_image_fin
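A usage sketch for the RealESRGAN wrapper; it needs the vendored real_esrgan package and basicsr, the checkpoint from configs/configs.json, and, because half=True, a CUDA device. The input path is a placeholder:

from PIL import Image
from upscaler.realesrgan_upscaler import RealESRGAN

upscaler = RealESRGAN("checkpoints/realesrgan_x4plus.pth")
big = upscaler.upscale(Image.open("result.png"), scale_factor=3)  # roughly 3x larger output
big.save("result_upscaled.png")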