kaiiddo commited on
Commit
80d002f
·
verified ·
1 Parent(s): e6dcfff

requirements.txt

Browse files

gradio>=4.0.0
torch>=2.0.0
diffusers>=0.25.0
transformers>=4.35.0
accelerate>=0.24.0
sentencepiece>=0.1.99
pillow>=10.0.0
numpy>=1.24.0

Files changed (1) hide show
  1. app.py +147 -0
app.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+ import gradio as gr
3
+ import numpy as np
4
+ import random
5
+ import torch
6
+ import os
7
+ from diffusers import SanaSprintPipeline
8
+ from PIL import Image
9
+
10
+ # Initialize device and dtype
11
+ dtype = torch.bfloat16
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+
14
+ # Load models
15
+ pipe = SanaSprintPipeline.from_pretrained(
16
+ "Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers",
17
+ torch_dtype=dtype
18
+ )
19
+ pipe2 = SanaSprintPipeline.from_pretrained(
20
+ "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
21
+ torch_dtype=dtype
22
+ )
23
+ pipe.to(device)
24
+ pipe2.to(device)
25
+
26
+ MAX_SEED = np.iinfo(np.int32).max
27
+ MAX_IMAGE_SIZE = 1024
28
+
29
+ def generate_image(prompt, model_size, seed, randomize_seed, width, height, guidance_scale, steps):
30
+ if randomize_seed:
31
+ seed = random.randint(0, MAX_SEED)
32
+ generator = torch.Generator().manual_seed(seed)
33
+
34
+ selected_pipe = pipe if model_size == "0.6B" else pipe2
35
+
36
+ result = selected_pipe(
37
+ prompt=prompt,
38
+ guidance_scale=guidance_scale,
39
+ num_inference_steps=steps,
40
+ width=width,
41
+ height=height,
42
+ generator=generator,
43
+ output_type="pil"
44
+ )
45
+
46
+ image = result.images[0]
47
+ filename = f"output_{seed}.png"
48
+ image.save(filename)
49
+ return image, filename, seed
50
+
51
+ css = """
52
+ #col-container {
53
+ margin: 0 auto;
54
+ max-width: 800px;
55
+ }
56
+ """
57
+
58
+ with gr.Blocks(css=css) as demo:
59
+ with gr.Column(elem_id="col-container"):
60
+ gr.Markdown("# 🚀 Sana Sprint Image Generator")
61
+
62
+ with gr.Row():
63
+ with gr.Column():
64
+ prompt = gr.Textbox(
65
+ label="Enter Prompt",
66
+ placeholder="A surreal landscape with...",
67
+ lines=3
68
+ )
69
+
70
+ model_size = gr.Radio(
71
+ label="Model Size",
72
+ choices=["0.6B", "1.6B"],
73
+ value="1.6B"
74
+ )
75
+
76
+ with gr.Accordion("Advanced Settings", open=False):
77
+ seed = gr.Slider(
78
+ label="Seed",
79
+ minimum=0,
80
+ maximum=MAX_SEED,
81
+ value=42,
82
+ step=1
83
+ )
84
+ randomize_seed = gr.Checkbox(
85
+ label="Randomize Seed",
86
+ value=True
87
+ )
88
+
89
+ with gr.Row():
90
+ width = gr.Slider(
91
+ label="Width",
92
+ minimum=256,
93
+ maximum=MAX_IMAGE_SIZE,
94
+ value=1024,
95
+ step=32
96
+ )
97
+ height = gr.Slider(
98
+ label="Height",
99
+ minimum=256,
100
+ maximum=MAX_IMAGE_SIZE,
101
+ value=1024,
102
+ step=32
103
+ )
104
+
105
+ guidance_scale = gr.Slider(
106
+ label="Guidance Scale",
107
+ minimum=1.0,
108
+ maximum=15.0,
109
+ value=4.5,
110
+ step=0.1
111
+ )
112
+
113
+ steps = gr.Slider(
114
+ label="Inference Steps",
115
+ minimum=1,
116
+ maximum=50,
117
+ value=2,
118
+ step=1
119
+ )
120
+
121
+ generate_btn = gr.Button("Generate Image", variant="primary")
122
+
123
+ with gr.Column():
124
+ output_image = gr.Image(label="Generated Image")
125
+ file_output = gr.File(label="Download Image")
126
+ seed_info = gr.Textbox(label="Used Seed")
127
+
128
+ gr.Examples(
129
+ examples=[
130
+ ["a tiny astronaut hatching from an egg on the moon", "1.6B"],
131
+ ["🐶 Wearing 🕶 flying on the 🌈", "1.6B"],
132
+ ["an anime illustration of a wiener schnitzel", "0.6B"]
133
+ ],
134
+ inputs=[prompt, model_size],
135
+ outputs=[output_image, file_output, seed_info],
136
+ fn=generate_image,
137
+ cache_examples=True
138
+ )
139
+
140
+ generate_btn.click(
141
+ fn=generate_image,
142
+ inputs=[prompt, model_size, seed, randomize_seed, width, height, guidance_scale, steps],
143
+ outputs=[output_image, file_output, seed_info]
144
+ )
145
+
146
+ if __name__ == "__main__":
147
+ demo.launch(server_name="0.0.0.0")