# kaiiddo's picture
# requirements.txt
# 80d002f verified
# main.py
import gradio as gr
import numpy as np
import random
import torch
import os
from diffusers import SanaSprintPipeline
from PIL import Image
# Initialize device and dtype: weights are kept in bfloat16, and the GPU is
# used whenever PyTorch reports CUDA as available.
dtype = torch.bfloat16
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load models. Both checkpoints are loaded first (0.6B, then 1.6B) and only
# then moved to `device`; `pipe` is the 0.6B model, `pipe2` the 1.6B model.
_SANA_SPRINT_CHECKPOINTS = (
    "Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers",
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
)
pipe, pipe2 = [
    SanaSprintPipeline.from_pretrained(checkpoint, torch_dtype=dtype)
    for checkpoint in _SANA_SPRINT_CHECKPOINTS
]
for _pipeline in (pipe, pipe2):
    _pipeline.to(device)
del _pipeline
# Upper bound of the seed slider and of randomly drawn seeds: the largest
# signed 32-bit integer.
MAX_SEED = int(np.iinfo("int32").max)
# Upper bound, in pixels, of the width and height sliders.
MAX_IMAGE_SIZE = 1024
def generate_image(
    prompt,
    model_size,
    seed=42,
    randomize_seed=True,
    width=1024,
    height=1024,
    guidance_scale=4.5,
    steps=2,
):
    """Generate one image with the selected Sana Sprint pipeline and save it.

    Every parameter after ``model_size`` defaults to the initial value of the
    matching UI control, so callers that only supply a prompt and a model
    size (e.g. a ``gr.Examples`` wired to two inputs) can still call this.

    Args:
        prompt: Text prompt passed to the pipeline.
        model_size: ``"0.6B"`` selects ``pipe`` (the 0.6B checkpoint); any
            other value selects ``pipe2`` (the 1.6B checkpoint).
        seed: Seed for the torch generator; replaced when ``randomize_seed``
            is true.
        randomize_seed: If true, draw a new seed uniformly from
            ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Passed through to the pipeline.
        steps: Number of inference steps.

    Returns:
        ``(image, filename, seed)``: the first generated PIL image, the path
        it was saved to (``output_<seed>.png`` in the current working
        directory), and the seed actually used.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
    # Gradio sliders can deliver numbers as floats; torch.Generator.manual_seed
    # requires an int, and an int also keeps the filename as output_42.png.
    seed = int(seed)
    generator = torch.Generator().manual_seed(seed)
    selected_pipe = pipe if model_size == "0.6B" else pipe2
    result = selected_pipe(
        prompt=prompt,
        guidance_scale=float(guidance_scale),
        num_inference_steps=int(steps),
        width=int(width),
        height=int(height),
        generator=generator,
        output_type="pil",
    )
    image = result.images[0]
    filename = f"output_{seed}.png"
    image.save(filename)
    return image, filename, seed
# Stylesheet for the app: center the main column (elem_id "col-container")
# and cap its width at 800px.
css = (
    "\n"
    "#col-container {\n"
    "margin: 0 auto;\n"
    "max-width: 800px;\n"
    "}\n"
)
def _generate_example(example_prompt, example_model_size):
    """Generate the image for one ``gr.Examples`` row.

    The example rows only carry a prompt and a model size, but
    ``generate_image`` takes eight positional arguments. The remaining ones
    are filled with the initial values of the controls defined below:
    seed 42, randomize on, 1024x1024, guidance 4.5, 2 steps.
    """
    return generate_image(
        example_prompt, example_model_size, 42, True, 1024, 1024, 4.5, 2
    )


with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# 🚀 Sana Sprint Image Generator")
        with gr.Row():
            # Left column: prompt, model choice, advanced settings, button.
            with gr.Column():
                prompt = gr.Textbox(
                    label="Enter Prompt",
                    placeholder="A surreal landscape with...",
                    lines=3,
                )
                model_size = gr.Radio(
                    label="Model Size",
                    choices=["0.6B", "1.6B"],
                    value="1.6B",
                )
                with gr.Accordion("Advanced Settings", open=False):
                    seed = gr.Slider(
                        label="Seed",
                        minimum=0,
                        maximum=MAX_SEED,
                        value=42,
                        step=1,
                    )
                    randomize_seed = gr.Checkbox(
                        label="Randomize Seed",
                        value=True,
                    )
                    with gr.Row():
                        # step=32 keeps sizes on multiples of 32 pixels.
                        width = gr.Slider(
                            label="Width",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            value=1024,
                            step=32,
                        )
                        height = gr.Slider(
                            label="Height",
                            minimum=256,
                            maximum=MAX_IMAGE_SIZE,
                            value=1024,
                            step=32,
                        )
                    guidance_scale = gr.Slider(
                        label="Guidance Scale",
                        minimum=1.0,
                        maximum=15.0,
                        value=4.5,
                        step=0.1,
                    )
                    steps = gr.Slider(
                        label="Inference Steps",
                        minimum=1,
                        maximum=50,
                        value=2,
                        step=1,
                    )
                generate_btn = gr.Button("Generate Image", variant="primary")
            # Right column: results of the last generation.
            with gr.Column():
                output_image = gr.Image(label="Generated Image")
                file_output = gr.File(label="Download Image")
                seed_info = gr.Textbox(label="Used Seed")
        # With cache_examples=True Gradio runs `fn` on each row ahead of time,
        # so `fn` must accept exactly the values listed in `inputs`.
        gr.Examples(
            examples=[
                ["a tiny astronaut hatching from an egg on the moon", "1.6B"],
                ["🐢 Wearing 🕶 flying on the 🌈", "1.6B"],
                ["an anime illustration of a wiener schnitzel", "0.6B"],
            ],
            inputs=[prompt, model_size],
            outputs=[output_image, file_output, seed_info],
            fn=_generate_example,
            cache_examples=True,
        )
    generate_btn.click(
        fn=generate_image,
        inputs=[
            prompt,
            model_size,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            steps,
        ],
        outputs=[output_image, file_output, seed_info],
    )
if __name__ == "__main__":
    # Bind the server to all network interfaces (0.0.0.0) instead of only
    # localhost.
    _launch_options = {"server_name": "0.0.0.0"}
    demo.launch(**_launch_options)