File size: 2,110 Bytes
4d1b5c0
0f3522e
 
 
 
 
 
4d1b5c0
 
0f3522e
4d1b5c0
0f3522e
4d1b5c0
0f3522e
 
4d1b5c0
0f3522e
 
4d1b5c0
0f3522e
 
 
 
 
 
 
 
4d1b5c0
0f3522e
 
4d1b5c0
0f3522e
 
 
 
4d1b5c0
0f3522e
 
 
 
 
 
 
4d1b5c0
 
 
 
0f3522e
 
4d1b5c0
0f3522e
 
 
4d1b5c0
0f3522e
 
4d1b5c0
0f3522e
4d1b5c0
0f3522e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# Stable Diffusion Hugging Face App (Turbo Version for Fast Startup)

import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
from transformers import CLIPTextModel, CLIPTokenizer

# Load the lightweight Stable Diffusion Turbo model
model_id = "stabilityai/sd-turbo"

# Prefer the GPU when torch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Half-precision weights are only requested on CUDA; the CPU path keeps fp32.
_weights_dtype = torch.float32
if device == "cuda":
    _weights_dtype = torch.float16

pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=_weights_dtype)
pipe = pipe.to(device)

# Swap the checkpoint's bundled scheduler for DDIM, built from the same
# scheduler config the checkpoint shipped with.
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

# Style presets simulated purely through prompt text (no learned embeddings
# are loaded). Each entry is (dropdown label, suffix appended to the prompt);
# the tuple order is the dict's insertion order.
_STYLE_ENTRIES = (
    ("Van Gogh", "in the style of Van Gogh"),
    ("Cyberpunk", "cyberpunk futuristic cityscape"),
    ("Pixel Art", "8-bit pixel art style"),
    ("Studio Ghibli", "studio ghibli anime style"),
    ("Surrealism", "in surrealistic dreamscape style"),
)
STYLE_MAP = dict(_STYLE_ENTRIES)

# Custom loss placeholder (for assignment purposes)
def custom_loss_placeholder(image_tensor):
    """Return the MSE between the image's mean colour and pure yellow.

    ``image_tensor`` is treated as channel-first with 3 colour channels,
    either a single image ``(C, H, W)`` or a batch ``(B, C, H, W)``.
    NOTE(review): values are presumably in ``[0, 1]`` because yellow is
    encoded as ``(1, 1, 0)``; confirm against whatever caller uses this
    (nothing in the visible code calls it).

    Returns:
        A 0-dim tensor: the mean squared error between the per-channel
        spatial mean and the yellow target (also averaged over the batch for
        4-D input).
    """
    # Reduce over the trailing spatial dims (H, W). For (C, H, W) this equals
    # the previous dim=[1, 2]; for (B, C, H, W) it yields (B, C) instead of
    # wrongly collapsing the channel axis.
    image_mean = image_tensor.mean(dim=(-2, -1))
    # Build the target on the input's device *and* dtype so fp16 inputs do not
    # produce a mixed-dtype loss.
    yellow = torch.tensor(
        [1.0, 1.0, 0.0], device=image_tensor.device, dtype=image_mean.dtype
    )
    # expand_as gives the target the exact shape of image_mean, so mse_loss
    # never has to broadcast (and warn) for batched input.
    return torch.nn.functional.mse_loss(image_mean, yellow.expand_as(image_mean))

# Generate image based on prompt and style

def compose_prompt(prompt, style_text):
    """Join a user prompt and a style suffix with ``", "``, skipping blanks.

    Args:
        prompt: Free text from the UI; ``None`` counts as empty, and
            surrounding whitespace is stripped.
        style_text: Style suffix to append (empty string for none).

    Returns:
        The combined prompt, with no dangling separator when either part is
        empty.
    """
    parts = ((prompt or "").strip(), style_text)
    return ", ".join(part for part in parts if part)


def generate(prompt, style, seed, num_inference_steps=1, guidance_scale=0.0):
    """Render one image for ``prompt`` decorated with the chosen style.

    Args:
        prompt: Free-text description from the UI textbox.
        style: Key into ``STYLE_MAP``; unknown keys add no style text.
        seed: RNG seed from the UI slider. It is coerced to ``int`` because
            the slider may deliver a float.
        num_inference_steps: Denoising steps. Defaults to 1, the setting the
            sd-turbo model card uses (the pipeline's own default is 50).
        guidance_scale: Classifier-free guidance weight. The sd-turbo model
            card disables guidance with 0.0. Values above 1 enable CFG, which
            roughly doubles the UNet work per step.

    Returns:
        The first image of the pipeline's ``.images`` output.
    """
    full_prompt = compose_prompt(prompt, STYLE_MAP.get(style, ""))
    # A per-call generator keeps the seed local to this request. The old
    # global torch.manual_seed was shared by every concurrent request.
    generator = torch.Generator(device=pipe.device).manual_seed(int(seed))
    result = pipe(
        full_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]
    return result

# Gradio UI: one row of inputs (prompt, style, seed), a trigger button, and
# the image output. Clicking the button passes the three inputs to
# `generate` and shows its return value in the output image.
with gr.Blocks() as demo:
    gr.Markdown("""# Stable Diffusion Turbo App\nGenerate styled images using text prompts and different art styles.""")

    with gr.Row():
        prompt = gr.Textbox(
            label="Enter Prompt",
            placeholder="A fox with a monocle",
        )
        style = gr.Dropdown(
            choices=list(STYLE_MAP),
            value="Van Gogh",
            label="Choose Style",
        )
        seed = gr.Slider(
            minimum=0,
            maximum=9999,
            step=1,
            value=42,
            label="Random Seed",
        )

    generate_btn = gr.Button("Generate Image")
    output = gr.Image(label="Stylized Output")

    generate_btn.click(
        fn=generate,
        inputs=[prompt, style, seed],
        outputs=output,
    )

# Launch the Gradio app only when this file is run as a script
# (`python app.py`). Importing the module, e.g. from tests or Gradio's reload
# mode, then builds `demo` without starting a server.
# NOTE(review): confirm the hosting environment runs this file as __main__.
if __name__ == "__main__":
    demo.launch()