Spaces:
Runtime error
Runtime error
File size: 3,505 Bytes
79fcfbd 78ea680 79fcfbd ee69d30 2c738fc 79fcfbd 78ea680 79fcfbd 4bc008c 79fcfbd 2c738fc 79fcfbd 78ea680 2c738fc 9b6512a 78ea680 c4cb9b3 79fcfbd ee69d30 c4cb9b3 ee69d30 79fcfbd ee69d30 c4cb9b3 90418b8 9b6512a 90418b8 ee69d30 90418b8 ee69d30 79fcfbd ee69d30 90418b8 ee69d30 79fcfbd ee69d30 79fcfbd ee69d30 79fcfbd ee69d30 79fcfbd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 |
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
import spaces # [uncomment to use ZeroGPU]
from diffusers import StableDiffusionXLPipeline
import torch
# Run on the GPU when one is visible, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model_repo_id = "stabilityai/stable-diffusion-xl-base-1.0"  # Replace to the model you would like to use
torch_dtype = torch.bfloat16

# Baseline (unpruned) SDXL pipeline.
pipe = StableDiffusionXLPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
pipe = pipe.to(device)

# Load the EcoDiff-pruned UNet. The checkpoint is a pickled module object
# (it is used directly as the pipeline's UNet), not a plain state dict, so
# weights_only=False is required: torch>=2.6 defaults to weights_only=True,
# which refuses to unpickle arbitrary objects.
# SECURITY: unpickling can execute arbitrary code -- only load checkpoints
# from a repository you trust.
pruned_unet = torch.load(
    hf_hub_download("zhangyang-0123/EcoDiffPrunedModels", "model/sdxl/sdxl.pkl"),
    map_location="cpu",
    weights_only=False,
)

# Build the pruned pipeline from the baseline's components so the VAE, text
# encoders and tokenizers are shared instead of being loaded a second time;
# only the UNet is replaced. The scheduler gets its own instance (same config)
# because set_timesteps() stores per-run state on it.
pruned_pipe = StableDiffusionXLPipeline(
    **{
        **pipe.components,
        "unet": pruned_unet,
        "scheduler": pipe.scheduler.__class__.from_config(pipe.scheduler.config),
    }
)
pruned_pipe = pruned_pipe.to(device)

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
@spaces.GPU  # ZeroGPU: request a GPU for the duration of each call
def generate_images(prompt, seed, steps):
    """Generate one image with the baseline pipeline and one with the pruned pipeline.

    Each pipeline gets a freshly seeded generator with the same seed, prompt and
    step count, so the two outputs are directly comparable.

    Args:
        prompt: Text prompt passed to both pipelines.
        seed: RNG seed from the UI; coerced to ``int`` because
            ``torch.Generator.manual_seed`` requires an integer.
        steps: Number of denoising steps; coerced to ``int``.

    Returns:
        Tuple ``(original_image, ecodiff_image)``: the first image produced by
        ``pipe`` and by ``pruned_pipe`` respectively.
    """
    seed = int(seed)
    steps = int(steps)

    def _run(pipeline):
        # A new generator per run so both pipelines start from identical noise.
        # Created on the module-level `device` rather than hard-coded "cuda" so
        # the CPU fallback chosen at import time also works here.
        generator = torch.Generator(device=device).manual_seed(seed)
        return pipeline(
            prompt=prompt, generator=generator, num_inference_steps=steps
        ).images[0]

    original_image = _run(pipe)
    ecodiff_image = _run(pruned_pipe)
    return original_image, ecodiff_image
# Sample prompts.
# NOTE(review): not referenced by the gr.Examples call in this file, which
# passes its own inline list of prompts.
examples = [
    "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
    "An astronaut riding a green horse",
    "A delicious ceviche cheesecake slice",
]

# Custom CSS handed to gr.Blocks; targets an element with id "col-container".
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""

# Page title rendered as Markdown.
header = """
# 🌱 EcoDiff Pruned SD-XL (20% Pruning Ratio)
"""

# Badge row linking to the paper (arXiv) and the pruned model on the Hub.
header_2 = """
<div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
<a href="https://arxiv.org/abs/2412.02852"><img src="https://img.shields.io/badge/arXiv-Paper-A42C25.svg" alt="arXiv"></a>
<a href="https://huggingface.co/zhangyang-0123/EcoDiffPrunedModels"><img src="https://img.shields.io/badge/🤗-Model-ffbd45.svg" alt="HuggingFace"></a>
</div>
"""
with gr.Blocks(css=css) as demo:
    # Title and badge row.
    gr.Markdown(header)
    gr.HTML(header_2)

    # Controls share one row; `scale` gives the prompt box 3x the width.
    with gr.Row():
        prompt = gr.Textbox(
            value="A clock tower floating in a sea of clouds",
            label="Prompt",
            scale=3,
        )
        seed = gr.Number(value=44, label="Seed", scale=1, precision=0)
        steps = gr.Slider(
            minimum=1, maximum=100, step=1, value=50,
            label="Number of Steps", scale=1,
        )
    generate_btn = gr.Button("Generate Images")

    # Clicking an example only fills in the prompt box.
    gr.Examples(
        inputs=[prompt],
        examples=[
            "A clock tower floating in a sea of clouds",
            "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
            "An astronaut riding a green horse",
            "A delicious ceviche cheesecake slice",
            "A sprawling cyberpunk metropolis at night, with towering skyscrapers emitting neon lights of every color, holographic billboards advertising alien languages",
        ],
    )

    # Baseline and pruned results side by side.
    with gr.Row():
        original_output = gr.Image(label="Original Output")
        ecodiff_output = gr.Image(label="EcoDiff Output")

    # One handler wired to both the button and Enter in the prompt box.
    gr.on(
        fn=generate_images,
        triggers=[generate_btn.click, prompt.submit],
        inputs=[prompt, seed, steps],
        outputs=[original_output, ecodiff_output],
    )

if __name__ == "__main__":
    demo.launch()
|