# diffusion / app.py
# Author: atharvasc27112001 · last commit c8ae620 "Update app.py" (verified)
# Stable Diffusion Hugging Face App (Turbo Version with Fixes)
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler
# Load the lightweight Stable Diffusion Turbo model.
model_id = "stabilityai/sd-turbo"
device = "cuda" if torch.cuda.is_available() else "cpu"
# float16 is used only on CUDA; the CPU path stays in float32.
torch_dtype = torch.float16 if device == "cuda" else torch.float32
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    torch_dtype=torch_dtype,
).to(device)
# Swap in DDIM, but pin "trailing" timestep spacing explicitly instead of
# relying on whatever the model's scheduler config carries. DDIM's own
# default is "leading", whose schedule for very few steps does not start at
# the highest-noise timestep, which a turbo model's first step expects.
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config,
    timestep_spacing="trailing",
)
# Dropdown label -> plain-text fragment appended to the user's prompt.
# These are ordinary prompt words, not learned embeddings; the insertion
# order below is the order the options appear in the style dropdown.
STYLE_MAP = dict(
    [
        ("Van Gogh", "in the style of Van Gogh"),
        ("Cyberpunk", "cyberpunk futuristic cityscape"),
        ("Pixel Art", "8-bit pixel art style"),
        ("Studio Ghibli", "studio ghibli anime style"),
        ("Surrealism", "in surrealistic dreamscape style"),
    ]
)
# Custom loss placeholder (for assignment purposes)
def custom_loss_placeholder(image_tensor):
    """Return the MSE between an image's mean colour and pure yellow.

    Placeholder objective (not called anywhere in this file): the mean of
    each channel over the two spatial axes is compared with RGB yellow
    ``(1, 1, 0)``.

    Args:
        image_tensor: Channels-first float tensor, either a single image
            ``(3, H, W)`` or a batch ``(..., 3, H, W)``. Values are presumably
            in ``[0, 1]`` (the scale yellow is expressed on) — TODO confirm
            against the eventual caller.

    Returns:
        A scalar tensor in the input's dtype; for a batch it is the mean
        squared error over all images and channels.

    Raises:
        ValueError: if the tensor has fewer than 3 dimensions or its channel
            axis (third from last) is not of size 3.
    """
    if image_tensor.dim() < 3 or image_tensor.shape[-3] != 3:
        raise ValueError(
            "expected a channels-first RGB tensor of shape (..., 3, H, W), "
            f"got {tuple(image_tensor.shape)}"
        )
    # Create the target directly on the input's device and in its dtype so
    # float16/float64 images never hit a dtype mismatch inside mse_loss.
    yellow = torch.tensor(
        [1.0, 1.0, 0.0], device=image_tensor.device, dtype=image_tensor.dtype
    )
    # Averaging the last two axes (instead of fixed dims 1 and 2) yields
    # shape (3,) for one image and (B, 3) for a batch.
    image_mean = image_tensor.mean(dim=(-2, -1))
    # Expanding the target to the exact shape avoids mse_loss's
    # size-mismatch broadcasting warning on batched input.
    return torch.nn.functional.mse_loss(image_mean, yellow.expand_as(image_mean))
# Generate image based on prompt and style
def generate(prompt, style, seed, *, guidance_scale=0.0, num_inference_steps=4):
    """Render one image for *prompt* with the chosen style text appended.

    Args:
        prompt: Free-text description from the UI textbox (``None`` is
            treated as empty).
        style: A key of ``STYLE_MAP``; an unknown key adds no style text.
        seed: Seed from the UI slider; converted with ``int()`` since the
            slider value may arrive as a float.
        guidance_scale: Classifier-free guidance weight. SD-Turbo was
            distilled to sample without guidance, so ``0.0`` disables CFG
            (and skips the extra unconditional UNet pass per step).
        num_inference_steps: Denoising steps; SD-Turbo targets 1-4 steps.

    Returns:
        The first image of the pipeline output (``.images[0]``).
    """
    # A private CPU generator keeps seeding local to this request rather than
    # reseeding torch's global RNG, which concurrent requests would share. It
    # draws the same random sequence that torch.manual_seed(seed) produced.
    generator = torch.Generator().manual_seed(int(seed))
    style_text = STYLE_MAP.get(style, "")
    # Join only non-empty parts so a blank prompt or unknown style does not
    # leave a dangling ", " in the text sent to the model.
    parts = ((prompt or "").strip(), style_text)
    full_prompt = ", ".join(part for part in parts if part)
    output = pipe(
        full_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        generator=generator,
    )
    return output.images[0]
# Gradio UI: prompt/style/seed inputs, a generate button and the image output.
# NOTE(review): the source's indentation was lost; which inputs sit inside the
# Row is inferred here — confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("""# Stable Diffusion Turbo App\nGenerate styled images using text prompts and different art styles.""")
    with gr.Row():
        prompt = gr.Textbox(placeholder="A fox with a monocle", label="Enter Prompt")
        style = gr.Dropdown(
            choices=list(STYLE_MAP),
            value="Van Gogh",
            label="Choose Style",
        )
        seed = gr.Slider(
            minimum=0,
            maximum=9999,
            step=1,
            value=42,
            label="Random Seed",
        )
    generate_btn = gr.Button("Generate Image")
    output = gr.Image(label="Stylized Output")
    # The order of `inputs` maps positionally onto generate(prompt, style, seed).
    generate_btn.click(fn=generate, inputs=[prompt, style, seed], outputs=output)

# Start the Gradio server for the app defined above.
demo.launch()