Spaces:
Sleeping
Sleeping
# ------------------------------------------ | |
# TextDiffuser: Diffusion Models as Text Painters | |
# Paper Link: https://arxiv.org/abs/2305.10855 | |
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser | |
# Copyright (c) Microsoft Corporation. | |
# This file provides the inference script. | |
# ------------------------------------------ | |
import os | |
from PIL import Image | |
import numpy as np | |
import torch | |
from tqdm import tqdm | |
import argparse | |
import cv2 | |
import torchvision.transforms as transforms | |
# Module-level torchvision converter from image tensors to PIL images.
# NOTE(review): only referenced from commented-out code, if at all; it appears to be
# kept for optional tensor-to-image conversion — confirm before removing.
to_pil_image = transforms.ToPILImage()
def load_stablediffusion():
    """Build the Stable Diffusion v1.5 text-to-image pipeline used as a baseline.

    The weights are loaded in float16; xformers memory-efficient attention and
    model CPU offloading are enabled on the pipeline.

    Returns:
        The configured ``StableDiffusionPipeline``.
    """
    from diffusers import StableDiffusionPipeline

    model_id = "runwayml/stable-diffusion-v1-5"
    sd_pipeline = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
    sd_pipeline.enable_xformers_memory_efficient_attention()
    sd_pipeline.enable_model_cpu_offload()
    return sd_pipeline
def test_stablediffusion(prompt, save_path, num_images_per_prompt=4, | |
pipe=None, generator=None): | |
images = pipe(prompt, num_inference_steps=50, generator=generator, num_images_per_prompt=num_images_per_prompt).images | |
for idx, image in enumerate(images): | |
image.save(save_path.replace('.jpg', '_' + str(idx) + '.jpg').replace('/images/', '/images_'+ str(idx) +'/')) | |
def load_deepfloyd_if():
    """Load the three stages of the DeepFloyd IF cascade in float16.

    Every stage has model CPU offloading enabled. Stage 2 is loaded without a
    text encoder, and stage 3 (the Stable Diffusion x4 upscaler) reuses stage 1's
    feature extractor, safety checker and watermarker.

    Returns:
        Tuple ``(stage_1, stage_2, stage_3)`` of diffusers pipelines.
    """
    from diffusers import DiffusionPipeline

    half = torch.float16
    # xformers memory-efficient attention is intentionally not enabled for any
    # stage: it is only wanted when torch.__version__ < 2.0.0.
    stage_1 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-I-XL-v1.0", variant="fp16", torch_dtype=half)
    stage_1.enable_model_cpu_offload()

    # No text encoder for stage 2: it is driven by prompt embeddings produced by stage 1.
    stage_2 = DiffusionPipeline.from_pretrained("DeepFloyd/IF-II-L-v1.0", text_encoder=None, variant="fp16",
                                                torch_dtype=half)
    stage_2.enable_model_cpu_offload()

    shared_safety = dict(
        feature_extractor=stage_1.feature_extractor,
        safety_checker=stage_1.safety_checker,
        watermarker=stage_1.watermarker,
    )
    stage_3 = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-x4-upscaler", **shared_safety,
                                                torch_dtype=half)
    stage_3.enable_model_cpu_offload()
    return stage_1, stage_2, stage_3
def test_deepfloyd_if(stage_1, stage_2, stage_3, prompt, save_path, num_images_per_prompt=4, generator=None): | |
idx = num_images_per_prompt - 1 # if the last image of a case exists, then return | |
new_save_path = save_path.replace('.jpg', '_' + str(idx) + '.jpg').replace('/images/', '/images_' + str(idx) + '/') | |
if os.path.exists(new_save_path): | |
return | |
if not stage_1 or not stage_2 or not stage_3: | |
stage_1, stage_2, stage_3 = load_deepfloyd_if() | |
if generator is None: | |
generator = torch.manual_seed(0) | |
prompt_embeds, negative_embeds = stage_1.encode_prompt(prompt) | |
stage_1.set_progress_bar_config(disable=True) | |
stage_2.set_progress_bar_config(disable=True) | |
stage_3.set_progress_bar_config(disable=True) | |
images = stage_1(prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, generator=generator, | |
output_type="pt", num_images_per_prompt=num_images_per_prompt).images | |
for idx, image in enumerate(images): | |
image = stage_2(image=image.unsqueeze(0), prompt_embeds=prompt_embeds, negative_prompt_embeds=negative_embeds, | |
generator=generator, output_type="pt").images | |
image = stage_3(prompt=prompt, image=image, generator=generator, noise_level=100).images | |
# image = to_pil_image(image[0].cpu()) | |
new_save_path = save_path.replace('.jpg', '_' + str(idx) + '.jpg').replace('/images/', '/images_'+ str(idx) +'/') | |
image[0].save(new_save_path) | |
def load_controlnet_cannyedge():
    """Build the Stable Diffusion v1.5 pipeline conditioned by a Canny-edge ControlNet.

    Weights are loaded in float16 with the safety checker disabled; the pipeline
    gets a UniPC multistep scheduler, silent progress bars, xformers attention and
    model CPU offloading.

    Returns:
        The configured ``StableDiffusionControlNetPipeline``.
    """
    from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler

    half = torch.float16
    canny_net = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny", torch_dtype=half)
    cn_pipe = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=canny_net,
        safety_checker=None,
        torch_dtype=half,
    )
    cn_pipe.set_progress_bar_config(disable=True)
    # Replace the default scheduler with UniPC, built from the existing scheduler config.
    cn_pipe.scheduler = UniPCMultistepScheduler.from_config(cn_pipe.scheduler.config)
    cn_pipe.enable_xformers_memory_efficient_attention()
    cn_pipe.enable_model_cpu_offload()
    return cn_pipe
def test_controlnet_cannyedge(prompt, save_path, canny_path, num_images_per_prompt=4,
                              pipe=None, generator=None, low_threshold=100, high_threshold=200):
    '''Generate images conditioned on the Canny edges of a reference image and save them.

    ref: https://github.com/huggingface/diffusers/blob/131312caba0af97da98fc498dfdca335c9692f8c/docs/source/en/api/pipelines/stable_diffusion/controlnet.mdx

    Image ``idx`` is written to ``save_path`` with ``_<idx>`` inserted before
    ``.jpg`` and ``/images/`` replaced by ``/images_<idx>/``.

    Args:
        prompt: Text prompt.
        save_path: Template path, e.g. ``<root>/images/<n>.jpg``.
        canny_path: Local path of the image whose edges condition generation;
            if it is not an existing file it is handed to ``load_image`` as-is.
        num_images_per_prompt: Number of images to sample.
        pipe: Pipeline from ``load_controlnet_cannyedge``; loaded when None.
        generator: Optional torch generator for reproducible sampling.
        low_threshold, high_threshold: Hysteresis thresholds for ``cv2.Canny``.
    '''
    from diffusers.utils import load_image
    if pipe is None:
        pipe = load_controlnet_cannyedge()

    if os.path.exists(canny_path):
        # Context manager so the file handle is released once the pixels are read.
        with Image.open(canny_path) as source:
            rgb = np.array(load_image(source))
    else:
        rgb = np.array(load_image(canny_path))

    edges = cv2.Canny(rgb, low_threshold, high_threshold)
    # Replicate the single-channel edge map into three identical channels.
    edges = np.concatenate([edges[:, :, None]] * 3, axis=2)
    control_image = Image.fromarray(edges)

    images = pipe(prompt, control_image, num_inference_steps=20, generator=generator,
                  num_images_per_prompt=num_images_per_prompt).images
    for idx, image in enumerate(images):
        image.save(save_path.replace('.jpg', '_' + str(idx) + '.jpg').replace('/images/', '/images_' + str(idx) + '/'))
def MARIOEval_generate_results(root, dataset, method='controlnet', num_images_per_prompt=4, split=0, total_split=1):
    """Generate baseline images for the prompts of one MARIOEval dataset.

    Prompts are read one per line from ``<root>/MARIOEval/<dataset>/<dataset>.txt``.
    For prompt ``idx`` the per-method helper is given the template path
    ``<root>/generation/<method>/<dataset>/images/<idx>.jpg``, and it saves image
    ``k`` under ``images_<k>/<idx>_<k>.jpg``. ControlNet additionally reads the
    conditioning image ``<root>/MARIOEval/<dataset>/render/<idx>.png``.

    Args:
        root: Evaluation root directory.
        dataset: Dataset name (sub-directory of ``<root>/MARIOEval``).
        method: One of ``'controlnet'``, ``'stablediffusion'``, ``'deepfloyd'``.
        num_images_per_prompt: Images generated per prompt.
        split: Index of the shard to process, ``0 <= split < total_split``.
        total_split: Number of shards the prompt list is divided into. Shards are
            contiguous, non-overlapping index ranges that together cover every prompt.

    Raises:
        ValueError: If ``method`` is unsupported or ``split``/``total_split`` are out of range.
    """
    supported_methods = ('controlnet', 'stablediffusion', 'deepfloyd')
    # Fail fast: an unsupported method (e.g. 'textdiffuser') would otherwise loop
    # over every prompt and generate nothing.
    if method not in supported_methods:
        raise ValueError(f"Unsupported method {method!r}; expected one of {supported_methods}")
    if total_split < 1 or not 0 <= split < total_split:
        raise ValueError(f"Expected 0 <= split < total_split, got split={split}, total_split={total_split}")

    root_eval = os.path.join(root, "MARIOEval")
    render_path = os.path.join(root_eval, dataset, 'render')
    root_res = os.path.join(root, "generation", method)
    for idx in range(num_images_per_prompt):
        os.makedirs(os.path.join(root_res, dataset, 'images_' + str(idx)), exist_ok=True)

    # One generator is shared by every prompt of this shard, so the sampled noise
    # for a prompt depends on its position within the shard.
    generator = torch.Generator(device="cuda").manual_seed(0)
    if method == 'controlnet':
        pipe = load_controlnet_cannyedge()
    elif method == 'stablediffusion':
        pipe = load_stablediffusion()
    else:  # 'deepfloyd'
        stage_1, stage_2, stage_3 = load_deepfloyd_if()

    # Explicit UTF-8 so decoding does not depend on the platform locale
    # (dataset choices include ChineseDrawText).
    with open(os.path.join(root_eval, dataset, dataset + '.txt'), 'r', encoding='utf-8') as fr:
        prompts = [line.strip() for line in fr]

    # Shard ``split`` owns the half-open index range [lower, upper). The strict
    # upper bound keeps the boundary prompt from being processed by two shards.
    num_prompts = len(prompts)
    lower = split * num_prompts / total_split
    upper = (split + 1) * num_prompts / total_split
    selected = [(idx, prompt) for idx, prompt in enumerate(prompts) if lower <= idx < upper]

    for idx, prompt in tqdm(selected):
        save_path = os.path.join(root_res, dataset, 'images', str(idx) + '.jpg')
        if method == 'controlnet':
            test_controlnet_cannyedge(prompt=prompt, num_images_per_prompt=num_images_per_prompt,
                                      save_path=save_path,
                                      canny_path=os.path.join(render_path, str(idx) + '.png'),
                                      pipe=pipe, generator=generator)
        elif method == 'stablediffusion':
            test_stablediffusion(prompt=prompt, num_images_per_prompt=num_images_per_prompt,
                                 save_path=save_path, pipe=pipe, generator=generator)
        else:
            test_deepfloyd_if(stage_1, stage_2, stage_3, num_images_per_prompt=num_images_per_prompt,
                              save_path=save_path, prompt=prompt, generator=generator)
def parse_args():
    """Parse the command-line options for MARIOEval baseline image generation.

    Returns:
        argparse.Namespace with ``dataset``, ``root``, ``method``, ``gpu``,
        ``split`` and ``total_split``.
    """
    parser = argparse.ArgumentParser(
        description="Generate MARIOEval evaluation images with baseline text-to-image methods.")
    parser.add_argument(
        "--dataset",
        type=str,
        default='TMDBEval500',
        required=False,
        choices=['TMDBEval500', 'OpenLibraryEval500', 'LAIONEval4000',
                 'ChineseDrawText', 'DrawBenchText', 'DrawTextCreative'],
        help="MARIOEval sub-dataset whose prompts are used.",
    )
    # Required, so no default: argparse would never use one.
    parser.add_argument(
        "--root",
        type=str,
        required=True,
        help="Evaluation root containing the MARIOEval directory; outputs go under <root>/generation.",
    )
    parser.add_argument(
        "--method",
        type=str,
        default='controlnet',
        required=False,
        choices=['controlnet', 'deepfloyd', 'stablediffusion', 'textdiffuser'],
        help="Generation method.",
    )
    parser.add_argument(
        "--gpu",
        type=int,
        default=0,
        required=False,
        help="GPU index to expose via CUDA_VISIBLE_DEVICES.",
    )
    parser.add_argument(
        "--split",
        type=int,
        default=0,
        required=False,
        help="0-based index of the prompt shard to process.",
    )
    parser.add_argument(
        "--total_split",
        type=int,
        default=1,
        required=False,
        help="Number of shards the prompt list is divided into.",
    )
    return parser.parse_args()
if __name__ == "__main__": | |
args = parse_args() | |
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu) | |
MARIOEval_generate_results(root=args.root, dataset=args.dataset, method=args.method, | |
split=args.split, total_split=args.total_split) | |