File size: 3,505 Bytes
79fcfbd
 
78ea680
79fcfbd
ee69d30
2c738fc
79fcfbd
 
 
78ea680
79fcfbd
4bc008c
79fcfbd
2c738fc
79fcfbd
 
78ea680
2c738fc
9b6512a
78ea680
 
 
 
 
c4cb9b3
79fcfbd
 
 
 
ee69d30
c4cb9b3
ee69d30
 
 
 
 
 
79fcfbd
 
 
 
 
 
 
 
 
 
 
 
 
 
ee69d30
c4cb9b3
90418b8
9b6512a
90418b8
ee69d30
90418b8
 
ee69d30
 
79fcfbd
 
ee69d30
90418b8
ee69d30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79fcfbd
ee69d30
 
79fcfbd
 
 
ee69d30
79fcfbd
ee69d30
79fcfbd
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download

import spaces  # [uncomment to use ZeroGPU]
from diffusers import StableDiffusionXLPipeline
import torch

# Run on the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"  # Replace with the model you would like to use

# dtype used when loading both pipelines from `model_repo_id`.
torch_dtype = torch.bfloat16

# Baseline (unpruned) SD-XL pipeline.
pipe = StableDiffusionXLPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)

# Load pruned model: the same SD-XL pipeline, with its UNet replaced by the
# EcoDiff-pruned UNet downloaded from the Hub. The checkpoint unpickles to a
# whole module object (it is assigned directly as `pruned_pipe.unet`), not a
# plain state_dict, so it must be loaded with weights_only=False. Since
# PyTorch 2.6 `torch.load` defaults to weights_only=True and would refuse it.
# SECURITY NOTE: weights_only=False runs arbitrary pickle code — only load
# checkpoints from a repository you trust.
# NOTE(review): the pickled UNet keeps whatever dtype it was saved with;
# `torch_dtype` above is not applied to it — confirm it matches bfloat16.
pruned_pipe = StableDiffusionXLPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pruned_pipe.unet = torch.load(
    hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/sdxl/sdxl.pkl"),
    map_location="cpu",
    weights_only=False,
)
pruned_pipe = pruned_pipe.to(device)


# Largest value representable as a signed 32-bit integer, and a pixel-size cap.
# NOTE(review): neither constant is referenced elsewhere in the visible code.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


@spaces.GPU  # [uncomment to use ZeroGPU]
def generate_images(prompt, seed, steps):
    """Generate one image with the original pipeline and one with the pruned pipeline.

    Each run gets its own freshly seeded generator with the same seed, so both
    pipelines consume an identical random stream. Both pipelines are built from
    the same repository; the pruned one differs only in its UNet.

    Args:
        prompt: Text prompt passed to both pipelines.
        seed: Seed for the random generator; coerced to ``int`` defensively,
            because ``torch.Generator.manual_seed`` rejects floats.
        steps: Number of inference steps; coerced to ``int`` defensively.

    Returns:
        Tuple ``(original_image, ecodiff_image)``: the first entry of each
        pipeline's ``.images`` output.
    """
    seed = int(seed)
    steps = int(steps)

    def _run(pipeline):
        # Create the generator on the same device the pipelines were moved to.
        # A hard-coded "cuda" breaks the CPU fallback chosen at load time.
        generator = torch.Generator(device).manual_seed(seed)
        return pipeline(prompt=prompt, generator=generator, num_inference_steps=steps).images[0]

    original_image = _run(pipe)
    ecodiff_image = _run(pruned_pipe)
    return original_image, ecodiff_image


# Sample prompts.
# NOTE(review): this list is not passed to gr.Examples in the visible code,
# which supplies its own inline prompt list — consider reusing one of them.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# Centers an element with id "col-container" and caps its width at 640px.
# NOTE(review): no component in the visible code sets elem_id="col-container",
# so this rule currently matches nothing.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""
# Markdown title rendered at the top of the demo.
header = """
# 🌱 EcoDiff Pruned SD-XL (20% Pruning Ratio)
"""

# Badge links to the paper (arXiv) and the pruned-model repository (Hugging Face).
header_2 = """
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
    <a href="https://arxiv.org/abs/2412.02852"><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
    <a href="https://huggingface.co/zhangyang-0123/EcoDiffPrunedModels"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
</div>
"""

# Build the demo UI. Gradio places components in the order they are created
# inside each context manager, so statement order here defines the layout.
with gr.Blocks(css=css) as demo:
    gr.Markdown(header)
    gr.HTML(header_2)
    # Input row: prompt, seed, step count and the generate button side by side.
    with gr.Row():
        prompt = gr.Textbox(
            label="Prompt",
            value="A clock tower floating in a sea of clouds",
            scale=3,
        )
        seed = gr.Number(label="Seed", value=44, precision=0, scale=1)
        steps = gr.Slider(
            label="Number of Steps",
            minimum=1,
            maximum=100,
            value=50,
            step=1,
            scale=1,
        )
        generate_btn = gr.Button("Generate Images")
    # Clickable sample prompts wired to the prompt textbox only.
    # NOTE(review): this inline list duplicates the module-level `examples`
    # (plus two extra prompts); consider keeping a single source of truth.
    gr.Examples(
        examples=[
            "A clock tower floating in a sea of clouds",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "An astronaut riding a green horse",
            "A delicious ceviche cheesecake slice",
            "A sprawling cyberpunk metropolis at night, with towering skyscrapers emitting neon lights of every color, holographic billboards advertising alien languages",
        ],
        inputs=[prompt],
    )
    # Output row: baseline and pruned results shown side by side.
    with gr.Row():
        original_output = gr.Image(label="Original Output")
        ecodiff_output = gr.Image(label="EcoDiff Output")
    # Run generate_images on button click or on prompt submit; its two return
    # values fill the two image components in order.
    gr.on(
        triggers=[generate_btn.click, prompt.submit],
        fn=generate_images,
        inputs=[
            prompt,
            seed,
            steps,
        ],
        outputs=[original_output, ecodiff_output],
    )

# Start the Gradio app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()