Bismay committed · commit 475e066

Initial commit
Browse files
- .gitignore +178 -0
- Dockerfile +41 -0
- LICENSE +21 -0
- README.md +38 -0
- app.py +429 -0
- colab_demo.py +201 -0
- configs/configs.json +20 -0
- download.py +34 -0
- download_models.py +22 -0
- parser/__init__.py +0 -0
- parser/schp_masker.py +200 -0
- parser/segformer_parser.py +185 -0
- parser/u2net_parser.py +55 -0
- requirements.txt +21 -0
- server.py +42 -0
- upscaler/__init__.py +0 -0
- upscaler/realesrgan_upscaler.py +30 -0
.gitignore
ADDED
@@ -0,0 +1,178 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
*.egg-info/
.installed.cfg
*.egg
.env
.venv
venv/
ENV/

# IDE
.idea/
.vscode/
*.swp
*.swo
.DS_Store

# Binary files and assets
parser/u2net_cloth_seg/assets/*.png
upscaler/real_esrgan/assets/*.png
upscaler/real_esrgan/assets/*.jpg
upscaler/real_esrgan/inputs/*.png
upscaler/real_esrgan/inputs/video/*.mp4
upscaler/real_esrgan/tests/data/gt.lmdb/
upscaler/real_esrgan/tests/data/gt/*.png

# Models
models/
*.pth
*.ckpt
*.safetensors

# Logs
*.log
logs/

# Temporary files
*.tmp
*.temp
*.bak
*.backup

# System files
.DS_Store
Thumbs.db

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Additional exclusions
*.swp
*.swo
Dockerfile
ADDED
@@ -0,0 +1,41 @@
# Must use a CUDA version 11+
# FROM pytorch/pytorch:1.11.0-cuda11.3-cudnn8-runtime
FROM pytorch/pytorch:1.13.0-cuda11.6-cudnn8-runtime

WORKDIR /
COPY ./parser /parser
COPY ./configs /configs
RUN mkdir /checkpoints
# Install git
RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    build-essential

# Install python packages
RUN pip3 install --upgrade pip
ADD requirements.txt requirements.txt
RUN pip3 install -r requirements.txt

# Install cv2 dependencies
RUN apt-get update
RUN apt-get install ffmpeg libsm6 libxext6 -y

# We add the banana boilerplate here
ADD server.py .

# Add your model weight files
# (in this case we have a python script)
ADD download.py .

# Add your custom app code, init() and inference()
ADD app.py .

ENV PYTHONPATH "${PYTHONPATH}:/parser:/upscaler"

# Alternative to using build args, you can put your token in the next line
# ENV HF_AUTH_TOKEN={token}
RUN python3 download.py

EXPOSE 8000

CMD python3 -u server.py
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2022 Banana

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
README.md
ADDED
@@ -0,0 +1,38 @@
---
title: ClothQuill - AI Clothing Inpainting
emoji: 👕
colorFrom: blue
colorTo: purple
sdk: gradio
sdk_version: 5.25.1
app_file: app.py
pinned: false
---

# ClothQuill - AI Clothing Inpainting

This Space lets you inpaint clothing on images using Stable Diffusion. Upload an image, provide a prompt describing the clothing you want to generate, and get multiple inpainted results.

## How to Use

1. Upload an image containing a person
2. Enter a prompt describing the clothing you want to generate
3. Click "Generate" to get multiple inpainted results
4. Download your favorite result

## Examples

- Prompt: "A stylish black leather jacket"
- Prompt: "A formal blue suit with white shirt"
- Prompt: "A casual red hoodie"

## Technical Details

This Space uses:
- Stable Diffusion 2 (inpainting) for generation
- SegFormer for clothing segmentation (a U2NET parser is also included)
- RealESRGAN for upscaling

## License

This project is licensed under the MIT License.
app.py
ADDED
@@ -0,0 +1,429 @@
import torch
from torch import autocast
from diffusers import StableDiffusionInpaintPipeline
import gradio as gr
import traceback
import base64
from io import BytesIO
import os
# import sys
import PIL
import json
import requests
import logging
import time
import warnings
import numpy as np
from PIL import Image, ImageDraw
import cv2
warnings.filterwarnings("ignore")

# sys.path.insert(1, './parser')

# from parser.schp_masker import *
from parser.segformer_parser import SegformerParser

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('clothquill')

# Model paths
SEGFORMER_MODEL = "mattmdjaga/segformer_b2_clothes"
STABLE_DIFFUSION_MODEL = "stabilityai/stable-diffusion-2-inpainting"

# Global variables for models
parser = None
model = None
inpainter = None
original_image = None  # Store the original uploaded image

# Color mapping for different clothing parts
CLOTHING_COLORS = {
    'Background': (0, 0, 0, 0),           # Transparent
    'Hat': (255, 0, 0, 128),              # Red
    'Hair': (0, 255, 0, 128),             # Green
    'Glove': (0, 0, 255, 128),            # Blue
    'Sunglasses': (255, 255, 0, 128),     # Yellow
    'Upper-clothes': (255, 0, 255, 128),  # Magenta
    'Dress': (0, 255, 255, 128),          # Cyan
    'Coat': (128, 0, 0, 128),             # Dark Red
    'Socks': (0, 128, 0, 128),            # Dark Green
    'Pants': (0, 0, 128, 128),            # Dark Blue
    'Jumpsuits': (128, 128, 0, 128),      # Dark Yellow
    'Scarf': (128, 0, 128, 128),          # Dark Magenta
    'Skirt': (0, 128, 128, 128),          # Dark Cyan
    'Face': (192, 192, 192, 128),         # Light Gray
    'Left-arm': (64, 64, 64, 128),        # Dark Gray
    'Right-arm': (64, 64, 64, 128),       # Dark Gray
    'Left-leg': (32, 32, 32, 128),        # Very Dark Gray
    'Right-leg': (32, 32, 32, 128),       # Very Dark Gray
    'Left-shoe': (16, 16, 16, 128),       # Almost Black
    'Right-shoe': (16, 16, 16, 128),      # Almost Black
}

def get_device():
    if torch.cuda.is_available():
        device = "cuda"
        logger.info("Using GPU")
    else:
        device = "cpu"
        logger.info("Using CPU")
    return device

def init():
    global parser
    global model
    global inpainter

    start_time = time.time()
    logger.info("Starting application initialization")

    try:
        device = get_device()

        # Check if models directory exists
        if not os.path.exists("models"):
            logger.info("Creating models directory...")
            from download_models import download_models
            download_models()

        # Initialize Segformer parser
        logger.info("Initializing Segformer parser...")
        parser = SegformerParser(SEGFORMER_MODEL)

        # Initialize Stable Diffusion model
        logger.info("Initializing Stable Diffusion model...")
        model = StableDiffusionInpaintPipeline.from_pretrained(
            STABLE_DIFFUSION_MODEL,
            safety_checker=None,
            revision="fp16" if device == "cuda" else None,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32
        ).to(device)

        # Initialize inpainter
        logger.info("Initializing inpainter...")
        inpainter = ClothingInpainter(model=model, parser=parser)

        logger.info(f"Application initialized in {time.time() - start_time:.2f} seconds")
    except Exception as e:
        logger.error(f"Error initializing application: {str(e)}")
        raise e

class ClothingInpainter:
    def __init__(self, model_path=None, model=None, parser=None):
        self.device = get_device()
        self.last_mask = None  # Store the last generated mask
        self.original_image = None  # Store the original image

        if model_path is None and model is None:
            raise ValueError('No model provided!')
        if model_path is not None:
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                model_path,
                safety_checker=None,
                revision="fp16" if self.device == "cuda" else None,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            ).to(self.device)
        else:
            self.pipe = model

        self.parser = parser

    def make_square(self, im, min_size=256, fill_color=(0, 0, 0, 0)):
        x, y = im.size
        size = max(min_size, x, y)
        new_im = PIL.Image.new('RGBA', (size, size), fill_color)
        new_im.paste(im, (int((size - x) / 2), int((size - y) / 2)))
        return new_im.convert('RGB')

    def unmake_square(self, init_im, op_im, min_size=256, rs_size=512):
        x, y = init_im.size
        size = max(min_size, x, y)
        factor = rs_size / size
        return op_im.crop((int((size - x) * factor / 2), int((size - y) * factor / 2),
                           int((size + x) * factor / 2), int((size + y) * factor / 2)))

    def visualize_segmentation(self, image, masks, selected_parts=None):
        """Visualize segmentation with colored overlays for selected parts and gray for unselected."""
        # Always use original image if available
        image_to_use = self.original_image if self.original_image is not None else image

        # Create a copy of the original image
        original_size = image_to_use.size
        vis_image = image_to_use.copy().convert('RGBA')

        # Create overlay at 512x512
        overlay = Image.new('RGBA', (512, 512), (0, 0, 0, 0))
        draw = ImageDraw.Draw(overlay)

        # Draw each mask with its corresponding color
        for part_name, mask in masks.items():
            # Convert part name for color lookup
            color_key = part_name.replace('-', ' ').title().replace(' ', '-')
            is_selected = selected_parts and part_name in selected_parts

            # If selected, use color (with fallback). If unselected, use faint gray
            if is_selected:
                color = CLOTHING_COLORS.get(color_key, (255, 0, 255, 128))  # Default to magenta if no color found
            else:
                color = (180, 180, 180, 80)  # Faint gray for unselected

            mask_array = np.array(mask)
            coords = np.where(mask_array > 0)
            for y, x in zip(coords[0], coords[1]):
                draw.point((x, y), fill=color)

        # Resize overlay to match original image size
        overlay = overlay.resize(original_size, Image.Resampling.LANCZOS)

        # Composite the overlay onto the original image
        vis_image = Image.alpha_composite(vis_image, overlay)
        return vis_image

    def inpaint(self, prompt, init_image, selected_parts=None, dilation_iterations=2) -> dict:
        image = self.make_square(init_image).resize((512, 512))

        if self.parser is not None:
            masks = self.parser.get_all_masks(image)
            masks = {k: v.resize((512, 512)) for k, v in masks.items()}
        else:
            raise ValueError('Image Parser is Missing')

        logger.info(f'[generated required mask(s) at {time.time()}]')

        # Create combined mask for selected parts
        if selected_parts:
            combined_mask = Image.new('L', (512, 512), 0)
            for part in selected_parts:
                if part in masks:
                    mask_array = np.array(masks[part])
                    kernel = np.ones((5, 5), np.uint8)
                    dilated_mask = cv2.dilate(mask_array, kernel, iterations=dilation_iterations)
                    dilated_mask = Image.fromarray(dilated_mask)
                    combined_mask = Image.composite(
                        Image.new('L', (512, 512), 255),
                        combined_mask,
                        dilated_mask
                    )
        else:
            # If no parts selected, use all clothing parts
            combined_mask = Image.new('L', (512, 512), 0)
            for part, mask in masks.items():
                if part in ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']:
                    mask_array = np.array(mask)
                    kernel = np.ones((5, 5), np.uint8)
                    dilated_mask = cv2.dilate(mask_array, kernel, iterations=dilation_iterations)
                    dilated_mask = Image.fromarray(dilated_mask)
                    combined_mask = Image.composite(
                        Image.new('L', (512, 512), 255),
                        combined_mask,
                        dilated_mask
                    )

        # Run the model
        guidance_scale = 7.5
        num_samples = 3
        with autocast("cuda"), torch.inference_mode():
            images = self.pipe(
                num_inference_steps=50,
                prompt=prompt['pos'],
                image=image,
                mask_image=combined_mask,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_samples,
            ).images

        images_output = []
        for img in images:
            ch = PIL.Image.composite(img, image, combined_mask)
            fin_img = self.unmake_square(init_image, ch)
            images_output.append(fin_img)

        return images_output

def process_segmentation(image, dilation_iterations=2):
    try:
        if image is None:
            raise gr.Error("Please upload an image")

        # Store original image
        inpainter.original_image = image.copy()

        # Create a processing copy at 512x512
        proc_image = image.resize((512, 512), Image.Resampling.LANCZOS)

        # Get the main mask
        all_masks = inpainter.parser.get_all_masks(proc_image)
        if not all_masks:
            logger.error("No clothing detected in the image")
            raise gr.Error("No clothing detected in the image. Please try a different image.")
        inpainter.last_mask = all_masks
        # Only show main clothing parts for selection
        main_parts = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
        masks = {k: v for k, v in all_masks.items() if k in main_parts}
        vis_image = inpainter.visualize_segmentation(image, masks, selected_parts=None)
        detected_parts = [k for k in masks.keys()]
        return vis_image, gr.update(choices=detected_parts, value=[])
    except gr.Error as e:
        raise e
    except Exception as e:
        logger.error(f"Error processing segmentation: {str(e)}")
        raise gr.Error("Error processing the image. Please try a different image.")

def update_dilation(image, selected_parts, dilation_iterations):
    try:
        if image is None or inpainter.last_mask is None:
            return image
        # Redilate all stored masks
        main_parts = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
        masks = {}
        for part in main_parts:
            if part in inpainter.last_mask:
                mask_array = np.array(inpainter.last_mask[part])
                kernel = np.ones((5, 5), np.uint8)
                dilated_mask = cv2.dilate(mask_array, kernel, iterations=dilation_iterations)
                masks[part] = Image.fromarray(dilated_mask)
        # Use original image for visualization
        vis_image = inpainter.visualize_segmentation(inpainter.original_image, masks, selected_parts=selected_parts)
        return vis_image
    except Exception as e:
        logger.error(f"Error updating dilation: {str(e)}")
        return image

def process_image(prompt, image, selected_parts, dilation_iterations):
    start_time = time.time()
    logger.info(f"Processing new request - Prompt: {prompt}, Image size: {image.size if image else 'None'}")

    try:
        if image is None:
            logger.error("No image provided")
            raise gr.Error("Please upload an image")
        if not prompt:
            logger.error("No prompt provided")
            raise gr.Error("Please enter a prompt")
        if not selected_parts:
            logger.error("No parts selected")
            raise gr.Error("Please select at least one clothing part to modify")

        prompt_dict = {'pos': prompt}
        logger.info("Starting inpainting process")

        # Generate inpainted images
        # Convert selected_parts to lowercase/dash format
        selected_parts = [p.lower() for p in selected_parts]
        images = inpainter.inpaint(prompt_dict, image, selected_parts, dilation_iterations)

        if not images:
            logger.error("Inpainting failed to produce results")
            raise gr.Error("Failed to generate images. Please try again.")

        logger.info(f"Request processed in {time.time() - start_time:.2f} seconds")
        return images
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        raise gr.Error(f"Error processing image: {str(e)}")

def update_selected_parts(image, selected_parts, dilation_iterations):
    try:
        if image is None or inpainter.last_mask is None:
            return image
        main_parts = ['upper-clothes', 'dress', 'coat', 'pants', 'skirt']
        masks = {}
        for part in main_parts:
            if part in inpainter.last_mask:
                mask_array = np.array(inpainter.last_mask[part])
                kernel = np.ones((5, 5), np.uint8)
                dilated_mask = cv2.dilate(mask_array, kernel, iterations=dilation_iterations)
                masks[part] = Image.fromarray(dilated_mask)
        # Lowercase the selected_parts for comparison
        selected_parts = [p.lower() for p in selected_parts] if selected_parts else []
        # Use original image for visualization
        vis_image = inpainter.visualize_segmentation(inpainter.original_image, masks, selected_parts=selected_parts)
        return vis_image
    except Exception as e:
        logger.error(f"Error updating selected parts: {str(e)}")
        return image

# Initialize the model
init()

# Create Gradio interface
with gr.Blocks(title="ClothQuill - AI Clothing Inpainting") as demo:
    gr.Markdown("# ClothQuill - AI Clothing Inpainting")
    gr.Markdown("Upload an image to see segmented clothing parts, then select parts to modify and describe your changes")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Upload Image",
                scale=1,      # This ensures the image maintains its aspect ratio
                height=None   # Allow dynamic height based on content
            )
            dilation_slider = gr.Slider(
                minimum=0,
                maximum=5,
                value=2,
                step=1,
                label="Mask Dilation",
                info="Adjust the mask dilation to control the area of modification"
            )
            selected_parts = gr.CheckboxGroup(
                choices=[],
                label="Select parts to modify",
                value=[]
            )
            prompt = gr.Textbox(
                label="Describe the clothing you want to generate",
                placeholder="e.g., A stylish black leather jacket"
            )
            generate_btn = gr.Button("Generate")

        with gr.Column():
            gallery = gr.Gallery(
                label="Generated Results",
                show_label=False,
                columns=2,
                height=None,           # Allow dynamic height
                object_fit="contain"   # Maintain aspect ratio
            )

    # Add event handler for image upload
    input_image.upload(
        fn=process_segmentation,
        inputs=[input_image, dilation_slider],
        outputs=[input_image, selected_parts]
    )

    # Add event handler for dilation changes
    dilation_slider.change(
        fn=update_dilation,
        inputs=[input_image, selected_parts, dilation_slider],
        outputs=input_image
    )

    # Add event handler for generation
    generate_btn.click(
        fn=process_image,
        inputs=[prompt, input_image, selected_parts, dilation_slider],
        outputs=gallery
    )

    # Add event handler for part selection changes
    selected_parts.change(
        fn=update_selected_parts,
        inputs=[input_image, selected_parts, dilation_slider],
        outputs=input_image
    )

if __name__ == "__main__":
    demo.launch(share=True)
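Editor's note: a minimal local-usage sketch of the ClothingInpainter defined above, not part of the commit. It assumes app.py's module-level init() succeeds (SegFormer parser plus the Stable Diffusion inpainting pipeline are downloaded), and the input file name and prompt below are illustrative.

# Usage sketch (hypothetical; not part of this commit).
# Importing app runs init(), which loads the parser and the inpainting pipeline.
from PIL import Image
import app

photo = Image.open("person.jpg")                    # hypothetical input image
app.inpainter.original_image = photo.copy()
results = app.inpainter.inpaint(
    {"pos": "a stylish black leather jacket"},      # prompt dict, as built in process_image()
    photo,
    selected_parts=["upper-clothes"],               # lowercase part names, as in process_image()
)
results[0].save("inpainted.png")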
colab_demo.py
ADDED
@@ -0,0 +1,201 @@
import torch
from torch import autocast
from diffusers import StableDiffusionInpaintPipeline
import gradio as gr
import traceback
import base64
from io import BytesIO
import os
import PIL
import json
import requests
import logging
import time
import warnings
warnings.filterwarnings("ignore")

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger('looks.studio')

# Model paths
SEGFORMER_MODEL = "mattmdjaga/segformer_b2_clothes"
STABLE_DIFFUSION_MODEL = "stabilityai/stable-diffusion-2-inpainting"

# Global variables for models
parser = None
model = None
inpainter = None

def get_device():
    if torch.cuda.is_available():
        device = "cuda"
        logger.info("Using GPU")
    else:
        device = "cpu"
        logger.info("Using CPU")
    return device

def init():
    global parser
    global model
    global inpainter

    start_time = time.time()
    logger.info("Starting application initialization")

    try:
        device = get_device()

        # Initialize Segformer parser
        logger.info("Initializing Segformer parser...")
        from parser.segformer_parser import SegformerParser
        parser = SegformerParser(SEGFORMER_MODEL)

        # Initialize Stable Diffusion model
        logger.info("Initializing Stable Diffusion model...")
        model = StableDiffusionInpaintPipeline.from_pretrained(
            STABLE_DIFFUSION_MODEL,
            safety_checker=None,
            revision="fp16" if device == "cuda" else None,
            torch_dtype=torch.float16 if device == "cuda" else torch.float32
        ).to(device)

        # Initialize inpainter
        logger.info("Initializing inpainter...")
        inpainter = ClothingInpainter(model=model, parser=parser)

        logger.info(f"Application initialized in {time.time() - start_time:.2f} seconds")
    except Exception as e:
        logger.error(f"Error initializing application: {str(e)}")
        raise e

class ClothingInpainter:
    def __init__(self, model_path=None, model=None, parser=None):
        self.device = get_device()

        if model_path is None and model is None:
            raise ValueError('No model provided!')
        if model_path is not None:
            self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                model_path,
                safety_checker=None,
                revision="fp16" if self.device == "cuda" else None,
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32
            ).to(self.device)
        else:
            self.pipe = model

        self.parser = parser

    def make_square(self, im, min_size=256, fill_color=(0, 0, 0, 0)):
        x, y = im.size
        size = max(min_size, x, y)
        new_im = PIL.Image.new('RGBA', (size, size), fill_color)
        new_im.paste(im, (int((size - x) / 2), int((size - y) / 2)))
        return new_im.convert('RGB')

    def unmake_square(self, init_im, op_im, min_size=256, rs_size=512):
        x, y = init_im.size
        size = max(min_size, x, y)
        factor = rs_size / size
        return op_im.crop((int((size - x) * factor / 2), int((size - y) * factor / 2),
                           int((size + x) * factor / 2), int((size + y) * factor / 2)))

    def inpaint(self, prompt, init_image, parser=None) -> dict:
        image = self.make_square(init_image).resize((512, 512))

        if self.parser is not None:
            mask = self.parser.get_image_mask(image)
            mask = mask.resize((512, 512))
        elif parser is not None:
            mask = parser.get_image_mask(image)
            mask = mask.resize((512, 512))
        else:
            raise ValueError('Image Parser is Missing')
        logger.info(f'[generated required mask(s) at {time.time()}]')

        # Run the model
        guidance_scale = 7.5
        num_samples = 3
        with autocast("cuda"), torch.inference_mode():
            images = self.pipe(
                num_inference_steps=50,
                prompt=prompt['pos'],
                image=image,
                mask_image=mask,
                guidance_scale=guidance_scale,
                num_images_per_prompt=num_samples,
            ).images

        images_output = []
        for img in images:
            ch = PIL.Image.composite(img, image, mask.convert('L'))
            fin_img = self.unmake_square(init_image, ch)
            images_output.append(fin_img)

        return images_output

def process_image(prompt, image):
    start_time = time.time()
    logger.info(f"Processing new request - Prompt: {prompt}, Image size: {image.size if image else 'None'}")

    try:
        if image is None:
            logger.error("No image provided")
            raise gr.Error("Please upload an image")
        if not prompt:
            logger.error("No prompt provided")
            raise gr.Error("Please enter a prompt")

        prompt_dict = {'pos': prompt}
        logger.info("Starting inpainting process")
        images = inpainter.inpaint(prompt_dict, image)

        if not images:
            logger.error("Inpainting failed to produce results")
            raise gr.Error("Failed to generate images. Please try again.")

        logger.info(f"Request processed in {time.time() - start_time:.2f} seconds")
        return images
    except Exception as e:
        logger.error(f"Error processing image: {str(e)}")
        raise gr.Error(f"Error processing image: {str(e)}")

# Initialize the model
init()

# Create Gradio interface
with gr.Blocks(title="Looks.Studio - AI Clothing Inpainting") as demo:
    gr.Markdown("# Looks.Studio - AI Clothing Inpainting")
    gr.Markdown("Upload an image and describe the clothing you want to generate")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                type="pil",
                label="Upload Image",
                height=512
            )
            prompt = gr.Textbox(label="Describe the clothing you want to generate")
            generate_btn = gr.Button("Generate")

        with gr.Column():
            gallery = gr.Gallery(
                label="Generated Images",
                show_label=False,
                columns=2,
                height=512
            )

    generate_btn.click(
        fn=process_image,
        inputs=[prompt, input_image],
        outputs=gallery
    )

if __name__ == "__main__":
    demo.launch(share=True)
configs/configs.json
ADDED
@@ -0,0 +1,20 @@
{
    "models": {
        "schp": {
            "download_url": "https://storage.googleapis.com/platform-ai-prod/looks_studio_data/schp_checkpoint.pth",
            "path": "checkpoints/schp.pth"
        },
        "u2net": {
            "download_url": "https://storage.googleapis.com/platform-ai-prod/looks_studio_data/cloth_segm_u2net_latest.pth",
            "path": "checkpoints/cloth_segm_u2net_latest.pth"
        },
        "realesrgan": {
            "download_url": "https://storage.googleapis.com/platform-ai-prod/looks_studio_data/RealESRGAN_x4plus.pth",
            "path": "checkpoints/realesrgan_x4plus.pth"
        },
        "diffuser": {
            "download_url": "https://storage.googleapis.com/platform-ai-prod/looks_studio_data/diffusers/stable_diffusion_2_checkpoint.tar",
            "path": "checkpoints/stable_diffusion_2_inpainting"
        }
    }
}
download.py
ADDED
@@ -0,0 +1,34 @@
# In this file, we define download_models
# It runs during container build time to get model weights built into the container

import os
import wget
import json
import tarfile
import tempfile

def download_models(config):
    # Download parser checkpoint
    # wget.download(config['schp']['download_url'],
    #               os.path.join(os.path.dirname(__file__), config['schp']['path']))
    wget.download(config['u2net']['download_url'],
                  os.path.join(os.path.dirname(__file__), config['u2net']['path']))

    # Download super resolution model
    wget.download(config['realesrgan']['download_url'],
                  os.path.join(os.path.dirname(__file__), config['realesrgan']['path']))

    # Download diffuser model checkpoint
    _, local_file_name = tempfile.mkstemp()
    local_file_name += '.tar'
    wget.download(config['diffuser']['download_url'], local_file_name)
    tar_file = tarfile.open(local_file_name)
    tar_file.extractall(os.path.join(os.path.dirname(__file__), 'checkpoints/'))

if __name__ == "__main__":
    config_file = "configs/configs.json"
    config_file = os.path.join(os.path.dirname(__file__), config_file)

    with open(config_file) as fin:
        config = json.load(fin)
    download_models(config['models'])
download_models.py
ADDED
@@ -0,0 +1,22 @@
import os
import time
import logging

logger = logging.getLogger(__name__)

def download_models():
    """Download required models for the application"""
    start_time = time.time()
    logger.info("Starting model download process")

    try:
        # Create models directory if it doesn't exist
        os.makedirs("models", exist_ok=True)

        logger.info(f"Model setup completed in {time.time() - start_time:.2f} seconds")
    except Exception as e:
        logger.error(f"Error in model setup: {str(e)}")
        raise

if __name__ == "__main__":
    download_models()
parser/__init__.py
ADDED
File without changes
parser/schp_masker.py
ADDED
@@ -0,0 +1,200 @@
import cv2
import sys
import numpy as np

import torch
from torch.utils import data
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

from PIL import Image
from collections import OrderedDict

sys.path.insert(1, './schp')
from utils.transforms import get_affine_transform
import networks
from utils.transforms import transform_logits

class PILImageDataset(data.Dataset):
    def __init__(self, img_lst=[], input_size=[512, 512], transform=None):
        self.img_lst = img_lst
        self.input_size = input_size
        self.transform = transform
        self.aspect_ratio = input_size[1] * 1.0 / input_size[0]
        self.input_size = np.asarray(input_size)

    def __len__(self):
        return len(self.img_lst)

    def _box2cs(self, box):
        x, y, w, h = box[:4]
        return self._xywh2cs(x, y, w, h)

    def _xywh2cs(self, x, y, w, h):
        center = np.zeros((2), dtype=np.float32)
        center[0] = x + w * 0.5
        center[1] = y + h * 0.5
        if w > self.aspect_ratio * h:
            h = w * 1.0 / self.aspect_ratio
        elif w < self.aspect_ratio * h:
            w = h * self.aspect_ratio
        scale = np.array([w, h], dtype=np.float32)
        return center, scale

    def __getitem__(self, index):
        img = np.array(self.img_lst[index])[:, :, ::-1]
        h, w, _ = img.shape

        # Get person center and scale
        person_center, s = self._box2cs([0, 0, w - 1, h - 1])
        r = 0
        trans = get_affine_transform(person_center, s, r, self.input_size)
        input = cv2.warpAffine(
            img,
            trans,
            (int(self.input_size[1]), int(self.input_size[0])),
            flags=cv2.INTER_LINEAR,
            borderMode=cv2.BORDER_CONSTANT,
            borderValue=(0, 0, 0))

        input = self.transform(input)
        meta = {
            'center': person_center,
            'height': h,
            'width': w,
            'scale': s,
            'rotation': r
        }

        return input, meta

PALLETE_DICT = {
    'Background': [],
    'Face': [],
    'Upper-clothes': [],
    'Dress': [],
    'Coat': [],
    'Soaks': [],
    'Pants': [],
    'Jumpsuits': [],
    'Scarf': [],
    'Skirt': [],
    'Arm': [],
    'Leg': [],
    'Shoe': []
}

val_list = [[0], [1, 4, 13], [5], [6], [7], [8], [9], [10], [11], [12], [14, 15], [16, 17], [18, 19]]
for c, j in enumerate(PALLETE_DICT.keys()):
    val = val_list[c]
    pallete = []
    for i in range(60):
        if len(val) == 1:
            if (i >= (val[0] * 3)) & (i < ((val[0] + 1) * 3)):
                pallete.append(255)
            else:
                pallete.append(0)
        if len(val) == 2:
            if (i >= (val[0] * 3)) & (i < ((val[0] + 1) * 3)) or (i >= (val[1] * 3)) & (i < ((val[1] + 1) * 3)):
                pallete.append(255)
            else:
                pallete.append(0)
        if len(val) == 3:
            if (i >= (val[0] * 3)) & (i < ((val[0] + 1) * 3)) or (i >= (val[1] * 3)) & (i < ((val[1] + 1) * 3)) or (i >= (val[2] * 3)) & (i < ((val[2] + 1) * 3)):
                pallete.append(255)
            else:
                pallete.append(0)

    PALLETE_DICT[j] = pallete


DATASET_SETTINGS = {
    'lip': {
        'input_size': [473, 473],
        'num_classes': 20,
        'label': ['Background', 'Hat', 'Hair', 'Glove', 'Sunglasses', 'Upper-clothes', 'Dress', 'Coat',
                  'Socks', 'Pants', 'Jumpsuits', 'Scarf', 'Skirt', 'Face', 'Left-arm', 'Right-arm',
                  'Left-leg', 'Right-leg', 'Left-shoe', 'Right-shoe']
    },
    'atr': {
        'input_size': [512, 512],
        'num_classes': 18,
        'label': ['Background', 'Hat', 'Hair', 'Sunglasses', 'Upper-clothes', 'Skirt', 'Pants', 'Dress', 'Belt',
                  'Left-shoe', 'Right-shoe', 'Face', 'Left-leg', 'Right-leg', 'Left-arm', 'Right-arm', 'Bag', 'Scarf']
    },
    'pascal': {
        'input_size': [512, 512],
        'num_classes': 7,
        'label': ['Background', 'Head', 'Torso', 'Upper Arms', 'Lower Arms', 'Upper Legs', 'Lower Legs'],
    }
}


class SCHPParser:
    def __init__(self, checkpoint_path, dataset_settings):
        self.cp_path = checkpoint_path
        self.ops = []
        self.num_classes = dataset_settings['lip']['num_classes']
        self.input_size = dataset_settings['lip']['input_size']
        self.label = dataset_settings['lip']['label']
        self.pallete_dict = PALLETE_DICT

        self.img_transforms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.406, 0.456, 0.485], std=[0.225, 0.224, 0.229])
        ])

        self.model = self.load_model()

    def load_model(self):
        model = networks.init_model('resnet101', num_classes=self.num_classes, pretrained=None)
        state_dict = torch.load(self.cp_path)['state_dict']
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = k[7:]  # remove `module.`
            new_state_dict[name] = v
        model.load_state_dict(new_state_dict)
        model.cuda()
        model.eval()
        return model

    def create_dataloader(self, img_lst):
        dataset = PILImageDataset(img_lst, input_size=self.input_size, transform=self.img_transforms)
        # dataset = SimpleFolderDataset('inputs', input_size, transform)
        dataloader = DataLoader(dataset)
        return dataloader

    def get_image_masks(self, img_lst):
        print("Evaluating total class number {} with {}".format(self.num_classes, self.label))
        dataloader = self.create_dataloader(img_lst)
        with torch.no_grad():
            for batch in dataloader:
                op_dict = {}
                image, meta = batch
                c = meta['center'].numpy()[0]
                s = meta['scale'].numpy()[0]
                w = meta['width'].numpy()[0]
                h = meta['height'].numpy()[0]

                output = self.model(image.cuda())
                upsample = torch.nn.Upsample(size=self.input_size, mode='bilinear', align_corners=True)
                upsample_output = upsample(output[0][-1][0].unsqueeze(0))
                upsample_output = upsample_output.squeeze()
                upsample_output = upsample_output.permute(1, 2, 0)  # CHW -> HWC

                logits_result = transform_logits(upsample_output.data.cpu().numpy(), c, s, w, h, input_size=self.input_size)
                parsing_result = np.argmax(logits_result, axis=2)
                output_img = Image.fromarray(np.asarray(parsing_result, dtype=np.uint8))
                for loc, key in enumerate(self.pallete_dict.keys()):
                    output_img.putpalette(self.pallete_dict[key])
                    op_dict.update({
                        key: output_img.convert('L')
                    })
                self.ops.append(op_dict)
        return self.ops
parser/segformer_parser.py
ADDED
@@ -0,0 +1,185 @@
import torch
import numpy as np
from PIL import Image
from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
import torch.nn.functional as F
import logging
import time
from typing import Tuple, Optional

logger = logging.getLogger('looks.studio.segformer')

class SegformerParser:
    def __init__(self, model_path="mattmdjaga/segformer_b2_clothes"):
        self.start_time = time.time()
        logger.info(f"Initializing SegformerParser with model: {model_path}")

        try:
            self.processor = SegformerImageProcessor.from_pretrained(model_path)
            self.model = AutoModelForSemanticSegmentation.from_pretrained(model_path)
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logger.info(f"Using device: {self.device}")
            self.model.to(self.device)

            # Define clothing-related labels
            self.clothing_labels = {
                4: "upper-clothes",
                5: "skirt",
                6: "pants",
                7: "dress",
                8: "belt",
                9: "left-shoe",
                10: "right-shoe",
                14: "left-arm",
                15: "right-arm",
                17: "scarf"
            }

            logger.info(f"SegformerParser initialized in {time.time() - self.start_time:.2f} seconds")
        except Exception as e:
            logger.error(f"Failed to initialize SegformerParser: {str(e)}")
            raise

    def _resize_image(self, image: Image.Image, max_size: int = 1024) -> Tuple[Image.Image, float]:
        """Resize image while maintaining aspect ratio if it exceeds max_size"""
        width, height = image.size
        scale = 1.0

        if width > max_size or height > max_size:
            scale = max_size / max(width, height)
            new_width = int(width * scale)
            new_height = int(height * scale)
            logger.info(f"Resizing image from {width}x{height} to {new_width}x{new_height}")
            image = image.resize((new_width, new_height), Image.Resampling.LANCZOS)

        return image, scale

    def _validate_image(self, image: Image.Image) -> bool:
        """Validate input image"""
        if not isinstance(image, Image.Image):
            logger.error("Input is not a PIL Image")
            return False

        if image.mode not in ['RGB', 'RGBA']:
            logger.error(f"Unsupported image mode: {image.mode}")
            return False

        width, height = image.size
        if width < 64 or height < 64:
            logger.error(f"Image too small: {width}x{height}")
            return False

        if width > 4096 or height > 4096:
            logger.error(f"Image too large: {width}x{height}")
            return False

        return True

    def get_image_mask(self, image: Image.Image) -> Optional[Image.Image]:
        """Generate segmentation mask for clothing"""
        start_time = time.time()
        logger.info(f"Starting segmentation for image size: {image.size}")

        try:
            # Validate input image
            if not self._validate_image(image):
                return None

            # Convert RGBA to RGB if necessary
            if image.mode == 'RGBA':
                logger.info("Converting RGBA to RGB")
                image = image.convert('RGB')

            # Resize image if too large
            image, scale = self._resize_image(image)

            # Process the image
            logger.info("Processing image with Segformer")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)

            # Get predictions
            with torch.no_grad():
                outputs = self.model(**inputs)
                logits = outputs.logits.cpu()

            # Upsample logits to original image size
            upsampled_logits = F.interpolate(
                logits,
                size=image.size[::-1],
                mode="bilinear",
                align_corners=False,
            )

            # Get the segmentation mask
            pred_seg = upsampled_logits.argmax(dim=1)[0]

            # Create a binary mask for clothing
            mask = torch.zeros_like(pred_seg)
            for label_id in self.clothing_labels.keys():
                mask[pred_seg == label_id] = 255

            # Convert to PIL Image
            mask = Image.fromarray(mask.numpy().astype(np.uint8))

            # Resize mask back to original size if needed
            if scale != 1.0:
                original_size = (int(image.size[0] / scale), int(image.size[1] / scale))
                logger.info(f"Resizing mask back to original size: {original_size}")
                mask = mask.resize(original_size, Image.Resampling.NEAREST)

            logger.info(f"Segmentation completed in {time.time() - start_time:.2f} seconds")
            return mask

        except Exception as e:
            logger.error(f"Error during segmentation: {str(e)}")
            return None

    def get_all_masks(self, image: Image.Image) -> dict:
        """Return a dict of binary masks for each clothing part label."""
        start_time = time.time()
        logger.info(f"Starting per-part segmentation for image size: {image.size}")
        masks = {}
        try:
            # Validate input image
            if not self._validate_image(image):
                return masks

            # Convert RGBA to RGB if necessary
            if image.mode == 'RGBA':
                logger.info("Converting RGBA to RGB")
                image = image.convert('RGB')

            # Resize image if too large
            image, scale = self._resize_image(image)

            # Process the image
            logger.info("Processing image with Segformer for all masks")
            inputs = self.processor(images=image, return_tensors="pt").to(self.device)

            # Get predictions
            with torch.no_grad():
                outputs = self.model(**inputs)
                logits = outputs.logits.cpu()
            upsampled_logits = F.interpolate(
                logits,
                size=image.size[::-1],
                mode="bilinear",
                align_corners=False,
            )
            pred_seg = upsampled_logits.argmax(dim=1)[0]

            # For each clothing label, create a binary mask
            for label_id, part_name in self.clothing_labels.items():
                mask = (pred_seg == label_id).numpy().astype(np.uint8) * 255
                mask_img = Image.fromarray(mask)
                # Resize mask back to original size if needed
                if scale != 1.0:
                    original_size = (int(image.size[0] / scale), int(image.size[1] / scale))
                    mask_img = mask_img.resize(original_size, Image.Resampling.NEAREST)
                masks[part_name] = mask_img

            logger.info(f"Per-part segmentation completed in {time.time() - start_time:.2f} seconds")
            return masks
        except Exception as e:
            logger.error(f"Error during per-part segmentation: {str(e)}")
            return masks
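Editor's note: a short usage sketch for the SegformerParser above, not part of the commit. It assumes the mattmdjaga/segformer_b2_clothes weights can be fetched from the Hugging Face Hub; the file name person.jpg is illustrative.

# Usage sketch (hypothetical; not part of this commit).
from PIL import Image
from parser.segformer_parser import SegformerParser

seg = SegformerParser("mattmdjaga/segformer_b2_clothes")
image = Image.open("person.jpg").convert("RGB")    # hypothetical input image

combined = seg.get_image_mask(image)    # single binary mask over all clothing labels
per_part = seg.get_all_masks(image)     # dict: part name -> binary PIL mask
if combined is not None:
    combined.save("clothing_mask.png")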
parser/u2net_parser.py
ADDED
@@ -0,0 +1,55 @@
import os
# from tqdm import tqdm
from PIL import Image
import numpy as np
import sys

import torch
import torch.nn.functional as F
import torchvision.transforms as transforms

from .u2net_cloth_seg.data.base_dataset import Normalize_image
from .u2net_cloth_seg.utils.saving_utils import load_checkpoint_mgpu

from .u2net_cloth_seg.networks import U2NET

class U2NETParser:
    def __init__(self, checkpoint_path):
        self.cp_path = checkpoint_path
        self.img_transforms = transforms.Compose([
            transforms.ToTensor(),
            Normalize_image(0.5, 0.5)
        ])
        self.model = self.load_model()

    def load_model(self):
        model = U2NET(in_ch=3, out_ch=4)
        model = load_checkpoint_mgpu(model, self.cp_path)
        model = model.to("cuda")
        model = model.eval()
        return model

    def get_image_mask(self, img):
        # print("Evaluating total class number {} with {}".format(self.num_classes, self.label))
        img_size = img.size
        img = img.resize((768, 768), Image.BICUBIC)
        image_tensor = self.img_transforms(img)
        image_tensor = torch.unsqueeze(image_tensor, 0)

        with torch.no_grad():
            output_tensor = self.model(image_tensor.to("cuda"))

        output_tensor = F.log_softmax(output_tensor[0], dim=1)
        output_tensor = torch.max(output_tensor, dim=1, keepdim=True)[1]
        output_tensor = torch.squeeze(output_tensor, dim=0)
        output_tensor = torch.squeeze(output_tensor, dim=0)
        output_arr = output_tensor.cpu().numpy()

        output_arr[output_arr != 1] = 0
        output_arr[output_arr == 1] = 255

        output_img = Image.fromarray(output_arr.astype('uint8'), mode='L')
        output_img = output_img.resize(img_size, Image.BICUBIC)

        return output_img
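Editor's note: a hypothetical usage sketch for the U2NET parser, not part of the commit. It assumes the u2net_cloth_seg package referenced by the relative imports is present alongside this file, a CUDA GPU is available, and the checkpoint has been downloaded to the path given in configs/configs.json; the input file name is illustrative.

# Usage sketch (hypothetical; not part of this commit).
from PIL import Image
from parser.u2net_parser import U2NETParser

u2net = U2NETParser("checkpoints/cloth_segm_u2net_latest.pth")   # path from configs/configs.json
mask = u2net.get_image_mask(Image.open("person.jpg").convert("RGB"))
mask.save("u2net_clothing_mask.png")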
requirements.txt
ADDED
@@ -0,0 +1,21 @@
sanic>=25.3.0
git+https://github.com/huggingface/diffusers.git#egg=diffusers
transformers>=4.30.0
scipy>=1.11.0
opencv-python>=4.8.0
wget
# ninja
accelerate>=0.24.0
basicsr>=1.4.2
ftfy>=6.1.1
# bitsandbytes
gradio>=3.50.0
# natsort
# https://github.com/metrolobo/xformers_wheels/releases/download/1d31a3ac_various_6/xformers-0.0.14.dev0-cp37-cp37m-linux_x86_64.whl
torch>=2.0.0
diffusers>=0.19.0
Pillow>=9.0.0
requests>=2.28.0
numpy>=1.24.0
huggingface_hub>=0.16.0
matplotlib>=3.7.0  # For visualization
server.py
ADDED
@@ -0,0 +1,42 @@
# Do not edit if deploying to Banana Serverless
# This file is boilerplate for the http server, and follows a strict interface.

# Instead, edit the init() and inference() functions in app.py

from sanic import Sanic, response
import subprocess
import app as user_src

# We do the model load-to-GPU step on server startup
# so the model object is available globally for reuse
user_src.init()

# Create the http server app
server = Sanic("my_app")

# Healthchecks verify that the environment is correct on Banana Serverless
@server.route('/healthcheck', methods=["GET"])
def healthcheck(request):
    # dependency free way to check if GPU is visible
    gpu = False
    out = subprocess.run("nvidia-smi", shell=True)
    if out.returncode == 0:  # success state on shell command
        gpu = True

    return response.json({"state": "healthy", "gpu": gpu})

# Inference POST handler at '/' is called for every http call from Banana
@server.route('/', methods=["POST"])
def inference(request):
    try:
        model_inputs = response.json.loads(request.json)
    except:
        model_inputs = request.json

    output = user_src.inference(model_inputs)

    return response.json(output)


if __name__ == '__main__':
    server.run(host='0.0.0.0', port=8000, workers=1)
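Editor's note: a hypothetical client sketch for the Banana-style server above, not part of the commit. It assumes the server is running locally on port 8000 and that app.py exposes the inference(model_inputs) function this boilerplate expects (the app.py in this commit does not yet define one); the payload keys are illustrative.

# Client sketch (hypothetical; not part of this commit).
import requests

# Health check: reports whether nvidia-smi succeeded on the server
print(requests.get("http://localhost:8000/healthcheck").json())

# Inference call: the payload schema is an assumption, not defined by this commit
payload = {"prompt": "a casual red hoodie"}
print(requests.post("http://localhost:8000/", json=payload).json())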
upscaler/__init__.py
ADDED
File without changes
upscaler/realesrgan_upscaler.py
ADDED
@@ -0,0 +1,30 @@
from basicsr.archs.rrdbnet_arch import RRDBNet
from .real_esrgan.realesrgan import RealESRGANer

import cv2
import numpy as np
from PIL import Image


class RealESRGAN:
    def __init__(self, checkpoint_path):

        self.netscale = 4

        model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)

        self.upsampler = RealESRGANer(
            scale=self.netscale,
            model_path=checkpoint_path,
            dni_weight=None,
            model=model,
            tile=0,
            tile_pad=10,
            pre_pad=0,
            half=True)

    def upscale(self, pil_image, scale_factor=3):
        cv2_img = cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGB2BGR)
        op, _ = self.upsampler.enhance(cv2_img, outscale=scale_factor)
        pil_image_fin = Image.fromarray(cv2.cvtColor(op, cv2.COLOR_BGR2RGB))
        return pil_image_fin
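Editor's note: a hypothetical usage sketch for the RealESRGAN wrapper, not part of the commit. It assumes the real_esrgan package referenced by the relative import is vendored alongside this file and that the RealESRGAN_x4plus weights have been downloaded to the path from configs/configs.json; the file names are illustrative.

# Usage sketch (hypothetical; not part of this commit).
from PIL import Image
from upscaler.realesrgan_upscaler import RealESRGAN

upscaler = RealESRGAN("checkpoints/realesrgan_x4plus.pth")   # path from configs/configs.json
big = upscaler.upscale(Image.open("inpainted.png"), scale_factor=3)
big.save("inpainted_upscaled.png")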