# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
# This file provides the inference script.
# ------------------------------------------

import os
from PIL import Image
import numpy as np
import torch
from tqdm import tqdm
import argparse
import cv2
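
def indexed_save_path(save_path, idx):
    # Small helper (newly factored out of the inline string replacement that was
    # repeated at every save site): maps '<root>/images/<name>.jpg' to
    # '<root>/images_<idx>/<name>_<idx>.jpg' so the idx-th sample of each prompt
    # lands in its own folder.
    return save_path.replace('.jpg', '_' + str(idx) + '.jpg').replace('/images/', '/images_' + str(idx) + '/')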

def load_stablediffusion():
    from diffusers import StableDiffusionPipeline
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)      
    pipe.enable_xformers_memory_efficient_attention()
    pipe.enable_model_cpu_offload()
    return pipe

def test_stablediffusion(prompt, save_path, num_images_per_prompt=4,
                         pipe=None, generator=None):
    if pipe is None:  # load lazily, mirroring test_controlnet_cannyedge
        pipe = load_stablediffusion()
    images = pipe(prompt, num_inference_steps=50, generator=generator,
                  num_images_per_prompt=num_images_per_prompt).images
    for idx, image in enumerate(images):
        image.save(indexed_save_path(save_path, idx))

def load_deepfloyd_if():
    from diffusers import DiffusionPipeline
    stage_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=torch.float16)
    # stage_1.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
    stage_1.enable_model_cpu_offload()
    stage_2 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16",
                                                torch_dtype=torch.float16)
    # stage_2.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
    stage_2.enable_model_cpu_offload()
    safety_modules = {"feature_extractor": stage_1.feature_extractor, "safety_checker": stage_1.safety_checker,
                      "watermarker": stage_1.watermarker}
    stage_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", **safety_modules,
                                                torch_dtype=torch.float16)
    # stage_3.enable_xformers_memory_efficient_attention()  # remove line if torch.__version__ >= 2.0.0
    stage_3.enable_model_cpu_offload()
    return stage_1, stage_2, stage_3


def test_deepfloyd_if(stage_1, stage_2, stage_3, prompt, save_path, num_images_per_prompt=4, generator=None):
    # Skip this prompt if its final image already exists (lets interrupted runs resume).
    if os.path.exists(indexed_save_path(save_path, num_images_per_prompt - 1)):
        return
    if stage_1 is None or stage_2 is None or stage_3 is None:
        stage_1, stage_2, stage_3 = load_deepfloyd_if()
    if generator is None:
        generator = torch.manual_seed(0)
    prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt)
    stage_1.set_progress_bar_config(disable=True)
    stage_2.set_progress_bar_config(disable=True)
    stage_3.set_progress_bar_config(disable=True)
    images = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator,
                    output_type="pt",  num_images_per_prompt=num_images_per_prompt).images
    for idx, image in enumerate(images):
        # stage 2 upscales the 64x64 stage-1 output to 256x256; stage 3 (the SD x4
        # upscaler) brings it to 1024x1024 and already returns PIL images.
        image = stage_2(image=image.unsqueeze(0), prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds,
                        generator=generator, output_type="pt").images
        image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100).images
        image[0].save(indexed_save_path(save_path, idx))


def load_controlnet_cannyedge():
    from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler

    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16)
    pipe = StableDiffusionControlNetPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", controlnet=controlnet,
                                                             safety_checker=None, torch_dtype=torch.float16)
    pipe.set_progress_bar_config(disable=True)
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_xformers_memory_efficient_attention()
    pipe.enable_model_cpu_offload()
    return pipe


def test_controlnet_cannyedge(prompt, save_path, canny_path, num_images_per_prompt=4,
                              pipe=None, generator=None, low_threshold=100, high_threshold=200):
    '''ref: https://github.com/huggingface/diffusers/blob/131312caba0af97da98fc498dfdca335c9692f8c/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx'''
    from diffusers.utils import load_image
    if pipe is None:
        pipe = load_controlnet_cannyedge()

    # load_image handles both local paths and URLs, so no separate Image.open is needed.
    image = load_image(canny_path)
    image = np.array(image)
    # turn the rendered text layout into a Canny edge map that conditions ControlNet
    image = cv2.Canny(image, low_threshold, high_threshold)
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)  # single channel -> 3-channel
    image = Image.fromarray(image)
    images = pipe(prompt, image, num_inference_steps=20, generator=generator, num_images_per_prompt=num_images_per_prompt).images
    for idx, image in enumerate(images):
        image.save(indexed_save_path(save_path, idx))


def MARIOEval_generate_results(root, dataset, method='controlnet', num_images_per_prompt=4, split=0, total_split=1):
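    """Generate num_images_per_prompt images per prompt for one MARIOEval subset.

    split/total_split shard the prompt list so multiple processes (one GPU each)
    can cover disjoint chunks; results are written under
    <root>/generation/<method>/<dataset>/images_<i>/.
    """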
    root_eval = os.path.join(root, "MARIOEval")
    render_path = os.path.join(root_eval, dataset, 'render')
    root_res = os.path.join(root, "generation", method)
    for idx in range(num_images_per_prompt): 
        os.makedirs(os.path.join(root_res, dataset, 'images_' + str(idx)), exist_ok=True)
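    # fixed seed so repeated runs (and different methods) sample from identical noise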
    generator = torch.Generator(device="cuda").manual_seed(0)
    if method == 'controlnet':
        pipe = load_controlnet_cannyedge()
    elif method == 'stablediffusion':
        pipe = load_stablediffusion()
    elif method == 'deepfloyd':
        stage_1, stage_2, stage_3 = load_deepfloyd_if()
    # note: 'textdiffuser' is accepted on the command line but not generated by this script

    with open(os.path.join(root_eval, dataset, dataset + '.txt'), 'r') as fr:
        prompts = fr.readlines()
        prompts = [_.strip() for _ in prompts]
    for idx, prompt in tqdm(enumerate(prompts), total=len(prompts)):
        # keep only this process's shard: indices in [split/total_split, (split+1)/total_split)
        # (>= on the upper bound so a boundary index is not processed by two shards)
        if idx < split * len(prompts) / total_split or idx >= (split + 1) * len(prompts) / total_split:
            continue
        if method == 'controlnet':
            test_controlnet_cannyedge(prompt=prompt, num_images_per_prompt=num_images_per_prompt,
                                      save_path=os.path.join(root_res, dataset, 'images', str(idx) + '.jpg'),
                                      canny_path=os.path.join(render_path, str(idx) + '.png'),
                                      pipe=pipe, generator=generator)
        elif method == 'stablediffusion':
            test_stablediffusion(prompt=prompt, num_images_per_prompt=num_images_per_prompt,
                                 save_path=os.path.join(root_res, dataset, 'images', str(idx) + '.jpg'),
                                 pipe=pipe, generator=generator)
        elif method == 'deepfloyd':
            test_deepfloyd_if(stage_1, stage_2, stage_3, num_images_per_prompt=num_images_per_prompt,
                              save_path=os.path.join(root_res, dataset, 'images', str(idx) + '.jpg'),
                              prompt=prompt, generator=generator)

def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--dataset",
        type=str,
        default='TMDBEval500',
        required=False,
        choices=['TMDBEval500', 'OpenLibraryEval500', 'LAIONEval4000',
                 'ChineseDrawText', 'DrawBenchText', 'DrawTextCreative']
    )
    parser.add_argument(
        "--root",
        type=str,
        default="/path/to/eval",
        required=True, 
    )
    parser.add_argument(
        "--method",
        type=str,
        default='controlnet', 
        required=False,
        choices=['controlnet', 'deepfloyd', 'stablediffusion', 'textdiffuser']
    )
    parser.add_argument(
        "--gpu",
        type=int,
        default=0,
        required=False,
    )
    parser.add_argument(
        "--split", 
        type=int,
        default=0,
        required=False,
    )
    parser.add_argument(
        "--total_split", 
        type=int,
        default=1,
        required=False,
    )
    args = parser.parse_args()
    return args


if __name__ == "__main__":
    args = parse_args()
    # must be set before any CUDA context is created (torch is imported above but not yet initialized on GPU)
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)

    MARIOEval_generate_results(root=args.root, dataset=args.dataset, method=args.method,
                               split=args.split, total_split=args.total_split)
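
# Example invocation (the script filename and root path below are illustrative):
#   python eval.py --root /data/eval --dataset TMDBEval500 --method controlnet --gpu 0
# To shard across 4 GPUs, launch 4 processes with --total_split 4 and --split 0..3.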