# Realtime-FLUX / app.py
# KingNish's picture
# Update app.py
# 3db3884 verified
import gradio as gr
import numpy as np
import random
import spaces
import torch
import time
from diffusers import DiffusionPipeline, AutoencoderTiny
from custom_pipeline import FluxWithCFGPipeline
# --- Torch Optimizations ---
# Global backend switches, applied once at import time before the pipeline is loaded.
# Permit TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True # Enable cuDNN benchmark for potentially faster convolutions
# --- Constants ---
# Shared limits and defaults used by both generate_image and the UI sliders.
# Largest signed 32-bit integer; upper bound when a random seed is drawn.
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 2048 # Keep a reasonable limit to prevent OOMs
# Initial slider values for the output resolution, in pixels.
DEFAULT_WIDTH = 1024
DEFAULT_HEIGHT = 1024
DEFAULT_INFERENCE_STEPS = 1 # FLUX Schnell is designed for few steps
# Bounds applied to the step count before it is passed to the pipeline.
MIN_INFERENCE_STEPS = 1
MAX_INFERENCE_STEPS = 8 # Allow slightly more steps for potential quality boost
ENHANCE_STEPS = 2 # Fixed steps for the enhance button
# --- Device and Model Setup ---
# Everything below runs at import time; a failed download/load raises here and
# aborts start-up (nothing in this section sets `pipe` to None).
dtype = torch.float16
# Use the GPU when one is visible to torch, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load the FLUX.1-schnell weights into the project's custom pipeline class.
pipe = FluxWithCFGPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=dtype)
# Replace the pipeline's VAE with the AutoencoderTiny checkpoint "madebyollin/taef1".
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taef1", torch_dtype=dtype)
pipe.to(device)
# --- Inference Function ---
@spaces.GPU
def generate_image(prompt: str, seed: int = 42, width: int = DEFAULT_WIDTH, height: int = DEFAULT_HEIGHT, randomize_seed: bool = False, num_inference_steps: int = DEFAULT_INFERENCE_STEPS, is_enhance: bool = False):
    """Generate one image from a text prompt with the FLUX pipeline.

    Args:
        prompt: Text description of the image. Empty or whitespace-only
            prompts are rejected with a warning.
        seed: Seed for the random generator; ignored when ``randomize_seed``
            is true.
        width: Requested image width, capped at ``MAX_IMAGE_SIZE``.
        height: Requested image height, capped at ``MAX_IMAGE_SIZE``.
        randomize_seed: Draw a fresh seed in ``[0, MAX_SEED]`` instead of
            using ``seed``.
        num_inference_steps: Requested step count, clamped to
            ``[MIN_INFERENCE_STEPS, MAX_INFERENCE_STEPS]``.
        is_enhance: Use the fixed ``ENHANCE_STEPS`` count instead of
            ``num_inference_steps``.

    Returns:
        A ``(image, seed, status)`` tuple. ``image`` is the pipeline output
        requested as PIL (``None`` for an empty prompt), ``seed`` is the seed
        actually used, and ``status`` is a latency summary or error note.

    Raises:
        gr.Error: If the pipeline is unavailable, the GPU runs out of memory,
            or any other failure occurs while preparing or running inference.
    """
    if pipe is None:
        raise gr.Error("Diffusion pipeline failed to load. Cannot generate images.")

    if not prompt or not prompt.strip():
        gr.Warning("Prompt is empty. Please enter a description.")
        return None, seed, "Error: Empty prompt"

    start_time = time.time()

    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Use fixed steps for the enhance button, otherwise use the slider value
    requested_steps = ENHANCE_STEPS if is_enhance else num_inference_steps

    try:
        # Inputs may arrive as floats or numeric strings (the seed was already
        # guarded this way); normalise all of them to ints inside the try so a
        # bad value surfaces as a gr.Error instead of an unhandled exception.
        seed = int(float(seed))
        # Clamp dimensions to avoid excessive memory usage
        width = min(int(float(width)), MAX_IMAGE_SIZE)
        height = min(int(float(height)), MAX_IMAGE_SIZE)
        # Clamp steps
        steps_to_use = max(MIN_INFERENCE_STEPS, min(int(float(requested_steps)), MAX_INFERENCE_STEPS))

        # Ensure generator is on the correct device
        generator = torch.Generator(device=device).manual_seed(seed)

        # Use inference_mode for efficiency
        with torch.inference_mode():
            # Assumes the custom pipeline honours return_dict=False and
            # returns a nested structure with the image at [0][0].
            result_img = pipe(
                prompt=prompt,
                width=width,
                height=height,
                num_inference_steps=steps_to_use,
                generator=generator,
                output_type="pil",  # Ensure PIL output for Gradio Image component
                return_dict=False,
            )[0][0]
    except torch.cuda.OutOfMemoryError as err:
        # Clear cache and suggest reducing size/steps
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise gr.Error("GPU ran out of memory. Try reducing the image width/height or the number of inference steps.") from err
    except Exception as err:
        # Clear cache just in case
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        raise gr.Error(f"An error occurred during generation: {err}") from err

    latency = time.time() - start_time
    latency_str = f"Latency: {latency:.2f} seconds (Steps: {steps_to_use})"
    # Return the normalised seed so the UI shows the value that was really used.
    return result_img, seed, latency_str
# --- Example Prompts ---
# Prompt strings offered through gr.Examples in the UI; each entry is fed to
# generate_image as the `prompt` argument only (all other arguments use defaults).
examples = [
    "a tiny astronaut hatching from an egg on the moon",
    "a cute white cat holding a sign that says hello world",
    "an anime illustration of Steve Jobs",
    "Create image of Modern house in minecraft style",
    "photo of a woman on the beach, shot from above. She is facing the sea, while wearing a white dress. She has long blonde hair",
    "Selfie photo of a wizard with long beard and purple robes, he is apparently in the middle of Tokyo. Probably taken from a phone.",
    "Photo of a young woman with long, wavy brown hair tied in a bun and glasses. She has a fair complexion and is wearing subtle makeup, emphasizing her eyes and lips. She is dressed in a black top. The background appears to be an urban setting with a building facade, and the sunlight casts a warm glow on her face.",
    "High-resolution photorealistic render of a sleek, futuristic motorcycle parked on a neon-lit street at night, rain reflecting the lights.",
    "Watercolor painting of a cozy bookstore interior with overflowing shelves and a cat sleeping in a sunbeam.",
]
# --- Gradio UI ---
# Builds the Blocks layout (image on the left, prompt/buttons/options on the
# right, example gallery below) and wires every event to generate_image.
with gr.Blocks() as demo:
    with gr.Column(elem_id="app-container"):
        gr.Markdown("# 🎨 Realtime FLUX Image Generator")
        gr.Markdown("Generate stunning images in real-time with Modified Flux.Schnell pipeline.")
        gr.Markdown("<span style='color: red;'>Note: Sometimes it stucks or stops generating images (I don't know why). In that situation just refresh the site.</span>")

        with gr.Row():
            # Integer scales 5:2 keep the intended 2.5:1 width ratio; the
            # original float scale (2.5) is not a valid integer scale value.
            with gr.Column(scale=5):
                result = gr.Image(label="Generated Image", show_label=False, interactive=False)
            with gr.Column(scale=2):
                prompt = gr.Text(
                    label="Prompt",
                    placeholder="Describe the image you want to generate...",
                    lines=3,
                    show_label=False,
                    container=False,
                )
                generateBtn = gr.Button("🖼️ Generate Image")
                enhanceBtn = gr.Button("🚀 Enhance Image")

                # NOTE(review): the original passed "Advanced Options" as a
                # positional argument, which Column does not accept as a label.
                # Use gr.Accordion("Advanced Options") if a titled, collapsible
                # section is wanted.
                with gr.Column():
                    with gr.Row():
                        realtime = gr.Checkbox(label="Realtime Toggler", info="If TRUE then uses more GPU but create image in realtime.", value=False)
                        latency = gr.Text(label="Latency")
                    with gr.Row():
                        seed = gr.Number(label="Seed", value=42)
                        randomize_seed = gr.Checkbox(label="Randomize Seed", value=True)
                    with gr.Row():
                        width = gr.Slider(label="Width", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_WIDTH)
                        height = gr.Slider(label="Height", minimum=256, maximum=MAX_IMAGE_SIZE, step=32, value=DEFAULT_HEIGHT)
                        num_inference_steps = gr.Slider(label="Inference Steps", minimum=MIN_INFERENCE_STEPS, maximum=MAX_INFERENCE_STEPS, step=1, value=DEFAULT_INFERENCE_STEPS)

        with gr.Row():
            gr.Markdown("### 🌟 Inspiration Gallery")
        with gr.Row():
            gr.Examples(
                examples=examples,
                fn=generate_image,
                inputs=[prompt],
                outputs=[result, seed, latency],
                cache_examples=True,
                cache_mode="eager"
            )

    # Component lists shared by every generation event below.
    generation_inputs = [prompt, seed, width, height, randomize_seed, num_inference_steps]
    generation_outputs = [result, seed, latency]

    def enhance_image(prompt_text, seed_value, width_value, height_value):
        """Re-render with the current seed using the fixed ENHANCE_STEPS count.

        The button previously called generate_image directly with only four
        inputs, so is_enhance stayed False and ENHANCE_STEPS was never used.
        """
        return generate_image(prompt_text, seed_value, width_value, height_value, is_enhance=True)

    def update_ui(realtime_enabled):
        """Hide the Generate button while realtime mode is on; keep the prompt editable."""
        return {
            prompt: gr.update(interactive=True),
            generateBtn: gr.update(visible=not realtime_enabled)
        }

    def realtime_generation(*args):
        """Run generate_image on input changes, but only when realtime mode is on.

        ``args`` is the realtime flag followed by generate_image's positional
        arguments. generate_image returns a plain tuple, so it is returned
        directly (calling ``next()`` on it raised TypeError).
        """
        realtime_enabled, *generation_args = args
        if not realtime_enabled:
            # One no-op update per output component instead of a bare None.
            return gr.update(), gr.update(), gr.update()
        return generate_image(*generation_args)

    enhanceBtn.click(
        fn=enhance_image,
        inputs=[prompt, seed, width, height],
        outputs=generation_outputs,
        show_progress="full"
    )

    generateBtn.click(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=generation_outputs,
        show_progress="full",
        api_name="RealtimeFlux",
    )

    realtime.change(
        fn=update_ui,
        inputs=[realtime],
        outputs=[prompt, generateBtn]
    )

    prompt.submit(
        fn=generate_image,
        inputs=generation_inputs,
        outputs=generation_outputs,
        show_progress="full"
    )

    # Any edit to these components triggers a realtime render; "always_last"
    # asks Gradio to run only the most recent pending event.
    for component in [prompt, width, height, num_inference_steps]:
        component.input(
            fn=realtime_generation,
            inputs=[realtime, *generation_inputs],
            outputs=generation_outputs,
            show_progress="hidden",
            trigger_mode="always_last"
        )
# Launch the app
# Starts the Gradio server for `demo` with default settings. There is no
# `__main__` guard, so this also runs whenever the module is imported.
demo.launch()