# Stable Diffusion Hugging Face App (Turbo Version with Fixes)
# NOTE: this source was captured from a Hugging Face Space page whose status read "Runtime error".
import gradio as gr
import torch
from diffusers import StableDiffusionPipeline, DDIMScheduler

# Load the lightweight Stable Diffusion Turbo model.
model_id = "stabilityai/sd-turbo"
device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusionPipeline.from_pretrained(
    model_id,
    # float16 on GPU, float32 on CPU (presumably because half precision is
    # poorly supported / slow on CPU).
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    # On GPU, fetch the checkpoint's fp16 weight files so the download matches
    # the dtype loaded above; on CPU use the default full-precision files.
    variant="fp16" if device == "cuda" else None,
).to(device)
# NOTE: the scheduler bundled with the checkpoint is kept on purpose. SD-Turbo's
# reference usage samples with the scheduler it ships with (1-4 steps, no CFG);
# an earlier revision replaced it with DDIMScheduler here. DDIMScheduler is still
# imported so any other code referring to it keeps working.
# Style name -> text appended to the user's prompt. These are plain prompt
# suffixes that imitate a style; no learned embeddings are involved.
# Order matters: the keys, in this order, become the UI dropdown choices.
STYLE_MAP = dict(
    [
        ("Van Gogh", "in the style of Van Gogh"),
        ("Cyberpunk", "cyberpunk futuristic cityscape"),
        ("Pixel Art", "8-bit pixel art style"),
        ("Studio Ghibli", "studio ghibli anime style"),
        ("Surrealism", "in surrealistic dreamscape style"),
    ]
)
# Custom loss placeholder (for assignment purposes).
def custom_loss_placeholder(image_tensor, target_rgb=(1.0, 1.0, 0.0)):
    """Return the MSE between an image's mean colour and a target colour.

    Placeholder loss; nothing in this file calls it. With the default target
    it measures how far the image's average colour is from pure yellow.

    Args:
        image_tensor: Channels-first tensor, either ``(C, H, W)`` or a batch
            ``(B, C, H, W)``, where ``C`` equals ``len(target_rgb)``.
            NOTE(review): assumes RGB values in [0, 1] -- confirm against
            whatever produces the tensor.
        target_rgb: Target colour, one value per channel. Defaults to yellow.

    Returns:
        0-dim tensor holding the mean-squared error, in ``image_tensor``'s dtype.
    """
    # Match the input's device *and* dtype so float16/float64 images work too.
    target = torch.tensor(
        target_rgb, dtype=image_tensor.dtype, device=image_tensor.device
    )
    # Average over the two trailing (spatial) dims. For (C, H, W) this equals
    # the former dim=[1, 2]; it also keeps a leading batch dim intact.
    image_mean = image_tensor.mean(dim=(-2, -1))
    # Expand explicitly so (B, C) vs (C,) does not trigger mse_loss's
    # size-mismatch broadcasting warning.
    return torch.nn.functional.mse_loss(image_mean, target.expand_as(image_mean))
# Generate image based on prompt and style.
def generate(prompt, style, seed, num_inference_steps=1, guidance_scale=0.0):
    """Generate one image for ``prompt`` with the chosen style suffix appended.

    Args:
        prompt: Free-text description from the UI textbox; ``None`` or blank
            is treated as empty.
        style: Key into ``STYLE_MAP``; unknown or empty styles add no suffix.
        seed: Seed for the sampling noise. Cast to ``int`` because UI sliders
            may deliver float values.
        num_inference_steps: Denoising steps. SD-Turbo is distilled for very
            few steps (1-4); 1 is its reference setting.
        guidance_scale: Classifier-free guidance weight. SD-Turbo is used
            without CFG, so the default 0.0 disables it (and the pipeline then
            skips the extra unconditional pass).

    Returns:
        The first entry of the pipeline's ``.images`` output (presumably a PIL
        image, diffusers' default output type).
    """
    # A dedicated CPU generator keeps the seed local to this request, instead of
    # reseeding torch's global RNG the way torch.manual_seed() does.
    generator = torch.Generator(device="cpu").manual_seed(int(seed))
    # Join only the non-empty parts so a blank prompt or unknown style does not
    # leave a dangling ", " in the text sent to the model.
    parts = ((prompt or "").strip(), STYLE_MAP.get(style, ""))
    full_prompt = ", ".join(part for part in parts if part)
    result = pipe(
        full_prompt,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    ).images[0]
    return result
# Gradio UI: prompt box, style picker and seed slider on one row feed
# ``generate``; its result is shown in the image component below the button.
with gr.Blocks() as demo:
    gr.Markdown("""# Stable Diffusion Turbo App\nGenerate styled images using text prompts and different art styles.""")
    with gr.Row():
        prompt = gr.Textbox(
            label="Enter Prompt",
            placeholder="A fox with a monocle",
        )
        style = gr.Dropdown(
            choices=list(STYLE_MAP),
            label="Choose Style",
            value="Van Gogh",
        )
        seed = gr.Slider(
            minimum=0,
            maximum=9999,
            step=1,
            value=42,
            label="Random Seed",
        )
    generate_btn = gr.Button("Generate Image")
    output = gr.Image(label="Stylized Output")
    generate_btn.click(
        fn=generate,
        inputs=[prompt, style, seed],
        outputs=output,
    )
# Launch the Gradio app.
# queue() routes events through Gradio's request queue, so slow generations
# (e.g. on CPU) are not cut off by a single long-lived HTTP request; it is
# harmless on Gradio versions where the queue is already on by default.
# The __main__ guard means importing this module does not start a server.
if __name__ == "__main__":
    demo.queue().launch()