File size: 5,043 Bytes
5e673fa
f9854a2
5e673fa
 
 
 
c7c8e9e
5e673fa
 
 
 
 
 
 
 
c7c8e9e
 
 
 
5e673fa
c7c8e9e
5e673fa
 
 
 
c7c8e9e
5e673fa
 
 
 
c7c8e9e
 
 
 
5e673fa
 
 
 
 
 
9975067
5e673fa
9975067
 
5e673fa
c7c8e9e
 
 
 
 
5e673fa
 
c7c8e9e
 
 
 
 
 
 
 
5e673fa
 
 
 
 
 
 
 
 
 
c7c8e9e
 
 
 
 
 
 
 
 
5e673fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5aebc40
5e673fa
 
c7c8e9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5e673fa
 
 
 
c7c8e9e
5e673fa
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import gradio as gr
import spaces
import numpy as np
import random
import spaces
import torch
from diffusers import SanaSprintPipeline

# Precision used for both checkpoints, and the device they are moved to
# (CUDA when torch reports it as available, CPU otherwise).
dtype = torch.bfloat16
device = "cuda" if torch.cuda.is_available() else "cpu"

# Upper bounds shared by the UI sliders and the random seed draw.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

# `pipe` is the 0.6B checkpoint, `pipe2` the 1.6B one; both are loaded up front
# and then moved to `device`.
pipe = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_0.6B_1024px_diffusers",
    torch_dtype=dtype,
)
pipe2 = SanaSprintPipeline.from_pretrained(
    "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
    torch_dtype=dtype,
)
pipe.to(device)
pipe2.to(device)

# NOTE(review): presumably `spaces.GPU(duration=5)` requests a short GPU slot
# on Hugging Face ZeroGPU hardware -- confirm against the `spaces` package docs.
@spaces.GPU(duration=5)
def infer(prompt, model_size, seed=42, randomize_seed=False, width=1024, height=1024, guidance_scale=3.5, num_inference_steps=2, progress=gr.Progress(track_tqdm=True)):
    """Generate one image for ``prompt`` with the selected Sana Sprint model.

    Args:
        prompt: Text prompt forwarded to the pipeline.
        model_size: ``"0.6B"`` selects ``pipe``; any other value selects ``pipe2``.
        seed: Seed for the torch generator; replaced when ``randomize_seed`` is set.
        randomize_seed: When true, draw a new seed in ``[0, MAX_SEED]``.
        width: Requested image width, forwarded to the pipeline.
        height: Requested image height, forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of inference steps forwarded to the pipeline.
        progress: Gradio progress tracker (not used directly in the body).

    Returns:
        A ``(image, seed)`` tuple: the first image of the pipeline output and
        the seed that was actually used.
    """
    # Resolve the effective seed first so it can be reported back to the caller.
    seed = random.randint(0, MAX_SEED) if randomize_seed else seed
    rng = torch.Generator().manual_seed(seed)

    if model_size == "0.6B":
        active_pipe = pipe
    else:
        active_pipe = pipe2

    call_kwargs = {
        "prompt": prompt,
        "guidance_scale": guidance_scale,
        "num_inference_steps": num_inference_steps,
        "width": width,
        "height": height,
        "generator": rng,
        "output_type": "pil",
    }
    output = active_pipe(**call_kwargs)
    print(output)

    first_image = output.images[0]
    return first_image, seed
    
# Example prompts, one list per model size.
# Prompts offered while the 0.6B model is selected.
examples_06B = [
    "a majestic castle on a floating island",
    "a robotic chef cooking in a futuristic kitchen",
    "a magical forest with glowing mushrooms",
]

# Prompts offered while the 1.6B model is selected.
examples_16B = [
    "a steampunk city with airships in the sky",
    "a photorealistic fox in a snowy landscape",
    "an underwater temple with ancient ruins",
]

# Stylesheet for the UI: centres the `#col-container` element and caps its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 520px;
}
"""

# Build the Gradio UI: a model-size selector, a prompt box with a Run button,
# the result image, advanced generation settings, and per-model example prompts.
with gr.Blocks(css=css) as demo:

    with gr.Column(elem_id="col-container"):
        gr.Markdown("""# Sana Sprint""")

        # Radio button selecting which of the two loaded pipelines `infer` uses.
        model_size = gr.Radio(
            label="Model Size",
            choices=["0.6B", "1.6B"],
            value="0.6B",
            interactive=True
        )

        with gr.Row():

            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            run_button = gr.Button("Run", scale=0)

        result = gr.Image(label="Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            with gr.Row():

                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,
                )

            with gr.Row():

                guidance_scale = gr.Slider(
                    label="Guidance Scale",
                    minimum=1,
                    maximum=15,
                    step=0.1,
                    value=1,
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=2,
                )

        with gr.Row():
            # With more than one input component, Gradio requires `examples`
            # to be a nested list (one inner list per sample). Each sample
            # only supplies the prompt; the model size is taken from the radio
            # button's current value when the example is run.
            #
            # Example outputs are not cached: the prompt set is swapped when
            # the model size changes, and a cache keyed by example position
            # would then return images generated for a different prompt/model.
            # `run_on_click=True` keeps the "click an example to generate"
            # behaviour without caching.
            examples_container = gr.Examples(
                examples=[[example_prompt] for example_prompt in examples_06B],  # Start with 0.6B examples
                fn=infer,
                inputs=[prompt, model_size],
                outputs=[result, seed],
                cache_examples=False,
                run_on_click=True,
                label="Example Prompts"
            )

        def update_examples(model_choice):
            """Return a Dataset update holding the example prompts for ``model_choice``."""
            prompts = examples_06B if model_choice == "0.6B" else examples_16B
            return gr.Dataset(samples=[[example_prompt] for example_prompt in prompts])

        # `gr.Examples` is a helper, not a component, so it cannot be an event
        # output and has no `update` method; the component that actually
        # renders the samples is its underlying `dataset`.
        model_size.change(
            fn=update_examples,
            inputs=[model_size],
            outputs=[examples_container.dataset],
        )

    # Run inference when the button is clicked or the prompt is submitted.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[prompt, model_size, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
        outputs=[result, seed]
    )

demo.launch()