aijack committed
Commit cbf51d9 · 1 Parent(s): 5f5dc50

Delete app.py

Files changed (1)
  1. app.py +0 -237
app.py DELETED
@@ -1,237 +0,0 @@
- import os
-
- import random
- import torch
- import gradio as gr
-
- from e4e.models.psp import pSp
- from util import *
- from huggingface_hub import hf_hub_download
-
- import tempfile
- from argparse import Namespace
- import shutil
-
- import dlib
- import numpy as np
- import torchvision.transforms as transforms
- from torchvision import utils
-
- from model.sg2_model import Generator
- from generate_videos import project_code_by_edit_name
- import urllib.request
- import clip
-
- # Fetch the example image used by the demo
- img_url = "http://claireye.com.tw/img/230212a.jpg"
- urllib.request.urlretrieve(img_url, "pose.jpg")
- model_dir = "models"
- os.makedirs(model_dir, exist_ok=True)
-
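- # Hugging Face Hub repos holding each pretrained checkpoint: (repo_id, filename)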
- model_repos = {
-     "e4e": ("aijack/e4e", "e4e.pt"),
-     "dlib": ("aijack/jojogan", "face_landmarks.dat"),
-     "base": ("aijack/stylegan2", "stylegan2.pt"),
-     "sketch": ("aijack/sketch", "sketch.pt"),
-     "jojo": ("aijack/jojo", "jojo.pt"),
-     "art": ("aijack/art", "art.pt"),
-     "arcane": ("aijack/arcane", "arcane.pt"),
- }
-
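- # Map UI modifier labels to InterfaceGAN edit directions: (attribute name, strength)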
- interface_gan_map = {
-     "None": None,
-     "Masculine": ("gender", 1.0), "Feminine": ("gender", -1.0),
-     "Smiling": ("smile", 1.0), "Frowning": ("smile", -1.0),
-     "Young": ("age", -1.0), "Old": ("age", 1.0),
-     "Long Hair": ("hair_length", -1.0), "Short Hair": ("hair_length", 1.0),
- }
-
-
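- # Download every checkpoint from the Hub and return a name -> local path mapping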
- def get_models():
-     os.makedirs(model_dir, exist_ok=True)
-
-     model_paths = {}
-
-     for model_name, repo_details in model_repos.items():
-         download_path = hf_hub_download(repo_id=repo_details[0], filename=repo_details[1])
-         model_paths[model_name] = download_path
-
-     return model_paths
-
-
- model_paths = get_models()
-
-
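- # Wraps model loading, face alignment, e4e inversion, latent editing, and StyleGAN2 rendering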
- class ImageEditor(object):
-     def __init__(self):
-         self.device = "cuda" if torch.cuda.is_available() else "cpu"
-
-         latent_size = 512
-         n_mlp = 8
-         channel_mult = 2
-         model_size = 1024
-
-         self.generators = {}
-
-         self.model_list = [name for name in model_paths.keys() if name not in ["e4e", "dlib"]]
-
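-         # Build one StyleGAN2 generator per style and load its fine-tuned weights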
-         for model in self.model_list:
-             g_ema = Generator(
-                 model_size, latent_size, n_mlp, channel_multiplier=channel_mult
-             ).to(self.device)
-
-             checkpoint = torch.load(model_paths[model], map_location=self.device)
-
-             g_ema.load_state_dict(checkpoint["g_ema"], strict=False)
-
-             self.generators[model] = g_ema
-
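-         # e4e encoder: inverts an aligned face into a latent code suitable for editing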
-         self.experiment_args = {"model_path": model_paths["e4e"]}
-         self.experiment_args["transform"] = transforms.Compose(
-             [
-                 transforms.Resize((256, 256)),
-                 transforms.ToTensor(),
-                 transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
-             ]
-         )
-         self.resize_dims = (256, 256)
-
-         model_path = self.experiment_args["model_path"]
-
-         ckpt = torch.load(model_path, map_location="cuda:0" if torch.cuda.is_available() else "cpu")
-         opts = ckpt["opts"]
-
-         opts["checkpoint_path"] = model_path
-         opts = Namespace(**opts)
-
-         self.e4e_net = pSp(opts, self.device)
-         self.e4e_net.eval()
-
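-         # dlib landmark predictor used by align_face for face alignment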
-         self.shape_predictor = dlib.shape_predictor(model_paths["dlib"])
-
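-         # CLIP model (loaded here but not used elsewhere in this file)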
-         self.clip_model, _ = clip.load("ViT-B/32", device=self.device)
-
-         print("setup complete")
-
-     def get_style_list(self):
-         style_list = []
-
-         for key in self.generators:
-             style_list.append(key)
-
-         return style_list
-
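-     # Align the input face, then run e4e to obtain its latent code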
-     def invert_image(self, input_image):
-         input_image = self.run_alignment(str(input_image))
-
-         input_image = input_image.resize(self.resize_dims)
-
-         img_transforms = self.experiment_args["transform"]
-         transformed_image = img_transforms(input_image)
-
-         with torch.no_grad():
-             images, latents = self.run_on_batch(transformed_image.unsqueeze(0))
-             result_image, latent = images[0], latents[0]
-
-         inverted_latent = latent.unsqueeze(0).unsqueeze(1)
-
-         return inverted_latent
-
-     def get_generators_for_styles(self, output_styles, loop_styles=False):
-         if "base" in output_styles:  # always start with base if chosen
-             output_styles.insert(0, output_styles.pop(output_styles.index("base")))
-         if loop_styles:
-             output_styles.append(output_styles[0])
-
-         return [self.generators[style] for style in output_styles]
-
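-     # Apply the selected InterfaceGAN edit to the latent, or pass it through unchanged for "None"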
-     def get_target_latent(self, source_latent, alter, generators):
-         np_source_latent = source_latent.squeeze(0).cpu().detach().numpy()
-         if alter == "None":
-             return random.choice([source_latent.squeeze(0)] * max(len(generators) - 1, 1))
-         edit = interface_gan_map[alter]
-         projected_code_np = project_code_by_edit_name(np_source_latent, edit[0], edit[1])
-         return torch.from_numpy(projected_code_np).float().to(self.device)
-
-     def edit_image(self, input, output_styles, edit_choices):
-         return self.predict(input, output_styles, edit_choices=edit_choices)
-
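-     # Full pipeline: invert the face once, then render the (optionally edited) latent through each selected style generator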
-     def predict(
-         self,
-         input,  # Input image path
-         output_styles,  # Style checkbox options
-         loop_styles=False,  # Loop back to the initial style
-         edit_choices=None,  # Optional modifier name from interface_gan_map
-     ):
-         # The UI passes the dropdown string; default to "None" (a dict here would
-         # crash get_target_latent's lookup into interface_gan_map)
-         if edit_choices is None:
-             edit_choices = "None"
-
-         out_dir = tempfile.mkdtemp()
-
-         inverted_latent = self.invert_image(input)
-         generators = self.get_generators_for_styles(output_styles, loop_styles)
-         output_paths = []
-
-         with torch.no_grad():
-             for g_ema in generators:
-                 latent_for_gen = self.get_target_latent(inverted_latent, edit_choices, generators)
-
-                 img, _ = g_ema([latent_for_gen], input_is_latent=True, truncation=1, randomize_noise=False)
-
-                 output_path = os.path.join(out_dir, f"out_{len(output_paths)}.jpg")
-                 utils.save_image(img, output_path, nrow=1, normalize=True, range=(-1, 1))
-
-                 output_paths.append(output_path)
-
-         return output_paths
-
-
-     def run_alignment(self, image_path):
-         aligned_image = align_face(filepath=image_path, predictor=self.shape_predictor)
-         print("Aligned image has shape: {}".format(aligned_image.size))
-         return aligned_image
-
-     def run_on_batch(self, inputs):
-         images, latents = self.e4e_net(
-             inputs.to(self.device).float(), randomize_noise=False, return_latents=True
-         )
-         return images, latents
-
-
- editor = ImageEditor()
-
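- # Gradio UI: image input, style checkboxes, modifier dropdown, gallery output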
- blocks = gr.Blocks(theme="darkdefault")
-
- with blocks:
-     gr.Markdown("<h1><center>Holiday Filters</center></h1>")
-     gr.Markdown(
-         "<div>Upload an image of your face, pick your desired output styles, pick any modifiers, and apply StyleGAN-based editing.</div>"
-     )
-     with gr.Row():
-         with gr.Column():
-             input_img = gr.Image(type="filepath", label="Input image")
-         with gr.Column():
-             style_choice = gr.CheckboxGroup(choices=editor.get_style_list(), value=editor.get_style_list(), type="value", label="Styles")
-             alter = gr.Dropdown(
-                 choices=["None", "Masculine", "Feminine", "Smiling", "Frowning", "Young", "Old", "Short Hair", "Long Hair"],
-                 value="None", label="Additional Modifiers")
-             img_button = gr.Button("Edit Image")
-     with gr.Row():
-         img_output = gr.Gallery(label="Output Images")
-         img_output.style(grid=(3, 3, 4, 4, 6, 6))
-
-     img_button.click(fn=editor.edit_image, inputs=[input_img, style_choice, alter], outputs=img_output)
-     ex = gr.Examples(examples=[["pose.jpg", editor.get_style_list(), "Smiling"], ["pose.jpg", editor.get_style_list(), "Long Hair"]],
-                      fn=editor.edit_image, inputs=[input_img, style_choice, alter],
-                      outputs=[img_output], cache_examples=True, run_on_click=True)
-     ex.dataset.headers = [""]
-     article = "<p style='text-align: center'><a href='http://claireye.com.tw'>Claireye</a> | 2023</p>"
-     gr.Markdown(article)
-
- blocks.launch(enable_queue=True)