# Hugging Face Space — status: Running on Zero (ZeroGPU hardware)
import logging
import math
import os
import random
import warnings

import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from gradio_imageslider import ImageSlider
from huggingface_hub import snapshot_download
from PIL import Image
# Stylesheet passed to gr.Blocks(css=...): rules for an element whose id is
# "col-container" (centered, at most 512px wide).
# NOTE(review): no component visible in this file sets elem_id="col-container".
css = "\n".join(
    [
        "",
        "#col-container {",
        "margin: 0 auto;",
        "max-width: 512px;",
        "}",
        "",
    ]
)
# Choose where the models run: CUDA when torch reports a usable GPU, otherwise
# the CPU. `power_device` is a human-readable label for the same choice.
if not torch.cuda.is_available():
    power_device, device = "CPU", "cpu"
    print("GPU is not available. Using CPU.")
else:
    power_device, device = "GPU", "cuda"
    print("GPU is available. Using CUDA.")
# Read the Hugging Face access token from the environment. It is passed to
# snapshot_download below; it is None when HUGGINGFACE_TOKEN is unset.
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")
print(f"Hugging Face token retrieved: {huggingface_token is not None}")

# Download (or reuse) a local snapshot of the FLUX.1-dev base model in
# ./FLUX.1-dev; `model_path` is the local directory that is returned.
print("Downloading model from Hugging Face Hub...")
model_path = snapshot_download(
    repo_id="black-forest-labs/FLUX.1-dev",
    repo_type="model",
    # Skip markdown docs and the git attributes file. The patterns are
    # fnmatch-style, so "*.gitattributes" matches ".gitattributes". A
    # double-dot pattern ("*..gitattributes") would never match.
    # NOTE(review): the repo may also hold large root-level checkpoints that
    # the diffusers loader does not use; consider ignoring them to save disk.
    # Verify the repo layout first.
    ignore_patterns=["*.md", "*.gitattributes"],
    local_dir="FLUX.1-dev",
    token=huggingface_token,
)
print(f"Model downloaded to: {model_path}")
# Load the jasperai upscaler ControlNet weights in bfloat16 and place them on
# the selected device.
print("Loading ControlNet model...")
controlnet = FluxControlNetModel.from_pretrained(
    "jasperai/Flux.1-dev-Controlnet-Upscaler",
    torch_dtype=torch.bfloat16,
)
controlnet = controlnet.to(device)
print("ControlNet model loaded.")

# Build the FLUX ControlNet pipeline from the local base-model snapshot,
# plugging in the ControlNet loaded above. Then move the pipeline to the
# device (DiffusionPipeline.to returns the pipeline itself).
print("Loading FluxControlNetPipeline...")
pipe = FluxControlNetPipeline.from_pretrained(
    model_path,
    controlnet=controlnet,
    torch_dtype=torch.bfloat16,
).to(device)
print("Pipeline loaded.")
# Largest seed the randomizer may draw, also the upper bound of the Seed slider.
MAX_SEED = 10**6
# Upper bound on the number of pixels in the upscaled output (1024 x 1024).
MAX_PIXEL_BUDGET = 1 << 20
# Prepare the uploaded image before it is upscaled.
def process_input(input_image, upscale_factor, **kwargs):
    """Fit `input_image` to the pixel budget and to model-friendly dimensions.

    If the upscaled output, (w * f) x (h * f), would exceed MAX_PIXEL_BUDGET
    pixels, the input is first shrunk uniformly so the upscaled result fits
    the budget. The uniform shrink keeps the aspect ratio. Both sides are then
    rounded down to a multiple of 8, with a minimum of 8.

    Args:
        input_image: PIL image to prepare.
        upscale_factor: Factor the caller will upscale the result by.
        **kwargs: Accepted for call-site compatibility and ignored.

    Returns:
        A tuple ``(image, w_original, h_original, was_resized)``.
        ``w_original`` and ``h_original`` are the size of `input_image` as
        received. ``was_resized`` is True when the pixel-budget shrink was
        applied.
    """
    print(f"Processing input image with upscale factor: {upscale_factor}")
    w, h = input_image.size
    w_original, h_original = w, h
    was_resized = False

    if w * h * upscale_factor**2 > MAX_PIXEL_BUDGET:
        # One scale factor for both sides keeps the aspect ratio. It also makes
        # (w*scale*f) * (h*scale*f) equal MAX_PIXEL_BUDGET. int() floors the
        # result, so the budget is never exceeded.
        scale = math.sqrt(MAX_PIXEL_BUDGET / (w * h * upscale_factor**2))
        target_w = max(1, int(w * scale))
        target_h = max(1, int(h * scale))
        message = (
            f"Requested output image is too large "
            f"({w * upscale_factor}x{h * upscale_factor}). "
            f"Resizing input to {target_w}x{target_h} pixels."
        )
        warnings.warn(message)
        print("Input image is too large, resizing...")
        gr.Info(message)
        input_image = input_image.resize((target_w, target_h))
        was_resized = True
        print(f"Image resized to: {input_image.size}")

    # Ensure that the dimensions are multiples of 8 (required by the model).
    # Clamp at 8 so very small inputs do not round down to a zero-sized image.
    w, h = input_image.size
    w = max(8, w - w % 8)
    h = max(8, h - h % 8)
    print(f"Resizing image to be multiple of 8: ({w}, {h})")
    return input_image.resize((w, h)), w_original, h_original, was_resized
# Inference entry point wired to the UI below. On ZeroGPU hardware,
# @spaces.GPU attaches a GPU for the duration of each call. Hugging Face
# documents the decorator as having no effect outside ZeroGPU.
@spaces.GPU
def infer(
    seed,
    randomize_seed,
    input_image,
    num_inference_steps,
    upscale_factor,
    controlnet_conditioning_scale,
    progress=gr.Progress(track_tqdm=True),
):
    """Upscale `input_image` with the FLUX ControlNet pipeline.

    Args:
        seed: Seed for the torch generator. It is replaced by a random value
            in [0, MAX_SEED] when `randomize_seed` is true.
        randomize_seed: Whether to draw a fresh random seed.
        input_image: PIL image from the UI, or None if nothing was uploaded.
        num_inference_steps: Denoising steps passed to the pipeline.
        upscale_factor: Integer-valued upscale factor.
        controlnet_conditioning_scale: ControlNet strength passed to the pipeline.
        progress: Gradio progress tracker; track_tqdm mirrors tqdm bars in the UI.

    Returns:
        ``[original input image, upscaled image, seed actually used]``.

    Raises:
        gr.Error: If no input image was provided.
    """
    if input_image is None:
        raise gr.Error("Please upload an input image.")

    # Slider values may arrive as floats. PIL sizes, the step count and the
    # torch seed all need ints.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    upscale_factor = int(upscale_factor)

    print(f"Starting inference with seed: {seed}, randomize_seed: {randomize_seed}")
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)
        print(f"Randomized seed: {seed}")

    # Keep the untouched upload for the before/after comparison slider.
    true_input_image = input_image
    # Normalize to 3-channel RGB so RGBA/palette/grayscale uploads yield an
    # RGB control image (assumed to be what the pipeline expects; confirm).
    input_image = input_image.convert("RGB")

    input_image, w_original, h_original, was_resized = process_input(
        input_image, upscale_factor
    )
    print(f"Processed input image. Original size: ({w_original}, {h_original}), Processed size: {input_image.size}")

    # The control image is the processed input enlarged to the output size.
    w, h = input_image.size
    control_image = input_image.resize((w * upscale_factor, h * upscale_factor))
    print(f"Control image resized to: {control_image.size}")

    generator = torch.Generator().manual_seed(seed)
    gr.Info("Upscaling image...")
    print("Running the pipeline to generate output image...")
    image = pipe(
        prompt="",  # empty prompt: no text guidance is supplied
        control_image=control_image,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        num_inference_steps=num_inference_steps,
        guidance_scale=3.5,
        height=control_image.size[1],
        width=control_image.size[0],
        generator=generator,
    ).images[0]
    print("Image generation completed.")

    # If the input was shrunk to fit the pixel budget, stretch the result
    # back to the originally requested output size.
    if was_resized:
        target_w = w_original * upscale_factor
        target_h = h_original * upscale_factor
        gr.Info(f"Resizing output image to targeted {target_w}x{target_h} size.")
        print(f"Resizing output image to original target size: ({target_w}, {target_h})")
        image = image.resize((target_w, target_h))

    print(f"Final output image size: {image.size}")
    # NOTE(review): the path is fixed, so concurrent requests overwrite each
    # other's file.
    image.save("output.jpg")
    print("Output image saved as 'output.jpg'")
    return [true_input_image, image, seed]
# Assemble the Gradio UI:
# - a Run button;
# - the input image beside its parameter controls;
# - an input/output comparison slider;
# - lazily cached examples.
with gr.Blocks(theme="Nymbo/Nymbo_Theme", css=css) as demo:
    gr.HTML("<center><h1>FLUX.1-Dev Upscaler</h1></center>")

    with gr.Row():
        run_button = gr.Button(value="Run")

    with gr.Row():
        # Wide column: the image upload, delivered to callbacks as a PIL image.
        with gr.Column(scale=4):
            input_im = gr.Image(type="pil", label="Input Image")
        # Narrow column: generation parameters.
        with gr.Column(scale=1):
            num_inference_steps = gr.Slider(
                label="Number of Inference Steps",
                value=28,
                minimum=8,
                maximum=50,
                step=1,
            )
            upscale_factor = gr.Slider(
                label="Upscale Factor",
                value=4,
                minimum=1,
                maximum=4,
                step=1,
            )
            controlnet_conditioning_scale = gr.Slider(
                label="Controlnet Conditioning Scale",
                value=0.6,
                minimum=0.1,
                maximum=1.5,
                step=0.1,
            )
            seed = gr.Slider(
                label="Seed",
                value=42,
                minimum=0,
                maximum=MAX_SEED,
                step=1,
            )
            randomize_seed = gr.Checkbox(value=True, label="Randomize seed")

    # Before/after comparison of the original upload and the upscaled result.
    with gr.Row():
        result = ImageSlider(type="pil", label="Input / Output", interactive=True)

    # Components in the positional order of infer's parameters; shared by the
    # examples and the Run button so the two wirings cannot drift apart.
    _infer_inputs = [
        seed,
        randomize_seed,
        input_im,
        num_inference_steps,
        upscale_factor,
        controlnet_conditioning_scale,
    ]

    examples = gr.Examples(
        examples=[
            [42, False, "examples/image_2.jpg", 28, 4, 0.6],
            [42, False, "examples/image_4.jpg", 28, 4, 0.6],
        ],
        inputs=_infer_inputs,
        outputs=result,
        fn=infer,
        cache_examples="lazy",
    )

    # Clicking Run calls infer and shows its output in the comparison slider;
    # the endpoint is hidden from the API page.
    gr.on(
        [run_button.click],
        fn=infer,
        inputs=_infer_inputs,
        outputs=result,
        show_api=False,
    )
# Start the app. queue() enables Gradio's request queue (Blocks.queue
# configures the Blocks in place). launch() serves it without a public share
# link and with the API page hidden.
print("Launching Gradio app...")
demo.queue()
demo.launch(share=False, show_api=False)