atharvasc27112001 committed on
Commit
4d1b5c0
·
verified ·
1 Parent(s): 4785dc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -13
app.py CHANGED
@@ -1,19 +1,22 @@
1
- # Stable Diffusion Hugging Face App (Gradio UI + Style Selection + Custom Loss Placeholder)
2
 
3
  import gradio as gr
4
  import torch
5
  from diffusers import StableDiffusionPipeline, DDIMScheduler
6
  from transformers import CLIPTextModel, CLIPTokenizer
7
 
8
- # Load pre-trained models
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
10
  pipe = StableDiffusionPipeline.from_pretrained(
11
- "runwayml/stable-diffusion-v1-5",
12
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
13
  ).to(device)
 
14
  pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
15
 
16
- # Example styles from textual inversion (simulated via prompts)
17
  STYLE_MAP = {
18
  "Van Gogh": "in the style of Van Gogh",
19
  "Cyberpunk": "cyberpunk futuristic cityscape",
@@ -22,15 +25,15 @@ STYLE_MAP = {
22
  "Surrealism": "in surrealistic dreamscape style"
23
  }
24
 
25
- # Custom loss placeholder (not applied at inference, for academic purposes)
26
  def custom_loss_placeholder(image_tensor):
27
- # Example: "yellow_loss" = penalize lack of yellow pixels
28
  yellow = torch.tensor([1.0, 1.0, 0.0]).to(image_tensor.device)
29
- image_mean = image_tensor.mean(dim=[1, 2]) # Average over H and W
30
  yellow_loss = torch.nn.functional.mse_loss(image_mean, yellow)
31
  return yellow_loss
32
 
33
  # Generate image based on prompt and style
 
34
  def generate(prompt, style, seed):
35
  torch.manual_seed(seed)
36
  full_prompt = f"{prompt}, {STYLE_MAP.get(style, '')}"
@@ -38,18 +41,20 @@ def generate(prompt, style, seed):
38
  return result
39
 
40
  # Gradio UI
41
- with gr.Blocks() as demo:
42
- gr.Markdown("""# Stable Diffusion Style Generator\nGenerate styled images using Stable Diffusion + Textual Inversion Styles.""")
 
 
43
 
44
  with gr.Row():
45
- prompt = gr.Textbox(label="Enter Prompt", placeholder="A cat riding a bicycle through space")
46
  style = gr.Dropdown(choices=list(STYLE_MAP.keys()), label="Choose Style", value="Van Gogh")
47
  seed = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Random Seed")
48
 
49
- btn = gr.Button("Generate Image")
50
  output = gr.Image(label="Stylized Output")
51
 
52
- btn.click(fn=generate, inputs=[prompt, style, seed], outputs=output)
53
 
54
- # Launch app
55
  demo.launch()
 
1
+ # Stable Diffusion Hugging Face App (Turbo Version for Fast Startup)
2
 
3
  import gradio as gr
4
  import torch
5
  from diffusers import StableDiffusionPipeline, DDIMScheduler
6
  from transformers import CLIPTextModel, CLIPTokenizer
7
 
8
+ # Load the lightweight Stable Diffusion Turbo model
9
+ model_id = "stabilityai/sd-turbo"
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
  pipe = StableDiffusionPipeline.from_pretrained(
13
+ model_id,
14
  torch_dtype=torch.float16 if device == "cuda" else torch.float32,
15
  ).to(device)
16
+
17
  pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
18
 
19
+ # Simulated style prompts (not using learned embeddings)
20
  STYLE_MAP = {
21
  "Van Gogh": "in the style of Van Gogh",
22
  "Cyberpunk": "cyberpunk futuristic cityscape",
 
25
  "Surrealism": "in surrealistic dreamscape style"
26
  }
27
 
28
+ # Custom loss placeholder (for assignment purposes)
29
  def custom_loss_placeholder(image_tensor):
 
30
  yellow = torch.tensor([1.0, 1.0, 0.0]).to(image_tensor.device)
31
+ image_mean = image_tensor.mean(dim=[1, 2])
32
  yellow_loss = torch.nn.functional.mse_loss(image_mean, yellow)
33
  return yellow_loss
34
 
35
  # Generate image based on prompt and style
36
+
37
  def generate(prompt, style, seed):
38
  torch.manual_seed(seed)
39
  full_prompt = f"{prompt}, {STYLE_MAP.get(style, '')}"
 
41
  return result
42
 
43
  # Gradio UI
44
+ demo = gr.Blocks()
45
+
46
+ with demo:
47
+ gr.Markdown("""# Stable Diffusion Turbo App\nGenerate styled images using text prompts and different art styles.""")
48
 
49
  with gr.Row():
50
+ prompt = gr.Textbox(label="Enter Prompt", placeholder="A fox with a monocle")
51
  style = gr.Dropdown(choices=list(STYLE_MAP.keys()), label="Choose Style", value="Van Gogh")
52
  seed = gr.Slider(minimum=0, maximum=9999, step=1, value=42, label="Random Seed")
53
 
54
+ generate_btn = gr.Button("Generate Image")
55
  output = gr.Image(label="Stylized Output")
56
 
57
+ generate_btn.click(fn=generate, inputs=[prompt, style, seed], outputs=output)
58
 
59
+ # Launch the Gradio app
60
  demo.launch()