import gradio as gr
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, EulerAncestralDiscreteScheduler
import torch
import numpy as np
import cv2
from PIL import Image
import spaces
# Device and precision ("cuda" is assumed; this Space runs on ZeroGPU hardware)
device = "cuda"
precision = torch.float16
# Load the ControlNet model for Canny edge conditioning
# xinsir/controlnet-canny-sdxl-1.0
# diffusers/controlnet-canny-sdxl-1.0
controlnet = ControlNetModel.from_pretrained(
    "xinsir/controlnet-canny-sdxl-1.0",
    torch_dtype=precision
)
# When testing with another base model, the VAE needs to be changed as well.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=precision)

# Scheduler
eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")
# Stable Diffusion XL base pipeline with ControlNet
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=precision,
    scheduler=eulera_scheduler,
)
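# Optional memory savers (assumption, not part of the original Space): diffusers
# exposes helpers like the ones below for smaller GPUs; they are left commented
# out because this app does not use them.
# pipe.enable_vae_slicing()
# pipe.enable_model_cpu_offload()  # would replace the pipe.to(device) call below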
# Load the LoRAs; each adapter_name is what set_adapters() uses below to switch styles
pipe.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipe.load_lora_weights("e-n-v-y/envy-junkworld-xl-01", weight_name="EnvyJunkworldXL01.safetensors", adapter_name="junkworld")
pipe.disable_lora()
def activate_ikea_lora():
    print("Activating IKEA LoRA")
    pipe.disable_lora()
    # set_adapters takes effect immediately, so no polling loop is needed
    pipe.set_adapters("ikea")
    pipe.enable_lora()
    print("IKEA LoRA active!")

def activate_pixel_lora():
    print("Activating PixelArt LoRA")
    pipe.disable_lora()
    pipe.set_adapters("pixel")
    pipe.enable_lora()
    print("PixelArt LoRA active!")

def activate_junkworld_lora():
    print("Activating JunkWorld LoRA")
    pipe.disable_lora()
    pipe.set_adapters("junkworld")
    pipe.enable_lora()
    print("JunkWorld LoRA active!")

def disable_loras():
    print("Deactivating LoRAs")
    pipe.disable_lora()
    print("All LoRAs deactivated!")
pipe.to(device)
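# Illustration only (not wired to the UI below): set_adapters also accepts
# several adapter names with per-adapter weights, so two of the loaded styles
# could be blended. A minimal sketch; the helper name and example weights are
# ours, not part of the original app.
def activate_blended_loras(weights=(0.8, 0.5)):
    pipe.set_adapters(["pixel", "junkworld"], adapter_weights=list(weights))
    pipe.enable_lora()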
# Edge detection function using OpenCV (Canny)
def apply_canny(image, low_threshold, high_threshold):
    image = np.array(image)
    image = cv2.Canny(image, low_threshold, high_threshold)
    # replicate the single-channel edge map into a 3-channel image for ControlNet
    image = image[:, :, None]
    image = np.concatenate([image, image, image], axis=2)
    return Image.fromarray(image)
# Image generation function
@spaces.GPU  # ZeroGPU: attach a GPU for the duration of each generation call
def generate_image(prompt, input_image, low_threshold, high_threshold, strength, guidance, controlnet_conditioning_scale):
    print(pipe.get_active_adapters())
    # Apply edge detection
    edge_detected = apply_canny(input_image, low_threshold, high_threshold)
    # Generate styled image using ControlNet
    result = pipe(
        prompt=prompt,
        image=edge_detected,
        num_inference_steps=30,
        guidance_scale=guidance,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        strength=strength,  # note: `strength` is an img2img parameter; the text-to-image ControlNet pipeline may ignore it
    ).images[0]
    return edge_detected, result
# Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 3D Screenshot to Styled Render with ControlNet")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload 3D Screenshot", type="pil")
            prompt = gr.Textbox(label="Style Prompt", placeholder="e.g., Futuristic building in sunset")
            low_threshold = gr.Slider(50, 150, value=100, label="Canny Edge Low Threshold")
            high_threshold = gr.Slider(100, 200, value=150, label="Canny Edge High Threshold")
            strength = gr.Slider(0.1, 1.0, value=0.7, label="Denoising Strength")
            guidance = gr.Slider(1, 20, value=7.5, label="Guidance Scale (Creativity)")
            controlnet_conditioning_scale = gr.Slider(0, 1, value=0.5, step=0.01, label="ControlNet Conditioning Scale")
            with gr.Row():
                ikea_lora_button = gr.Button("IKEA Instructions")
                pixel_lora_button = gr.Button("Pixel Art")
                junkworld_lora_button = gr.Button("Junk World")
                disable_lora_button = gr.Button("Disable LoRAs")
            generate_button = gr.Button("Generate Styled Image")
        with gr.Column():
            edge_output = gr.Image(label="Edge Detected Image")
            result_output = gr.Image(label="Generated Styled Image")
    # Generate button action
    generate_button.click(
        fn=generate_image,
        inputs=[prompt, input_image, low_threshold, high_threshold, strength, guidance, controlnet_conditioning_scale],
        outputs=[edge_output, result_output]
    )
    ikea_lora_button.click(fn=activate_ikea_lora)
    pixel_lora_button.click(fn=activate_pixel_lora)
    junkworld_lora_button.click(fn=activate_junkworld_lora)
    disable_lora_button.click(fn=disable_loras)
# Launch the app
demo.launch()
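# Rough local-run dependency sketch (assumption, not part of the original file):
# `peft` is needed for the named LoRA adapters, and `spaces` only matters on
# Hugging Face ZeroGPU hardware.
#   pip install gradio spaces torch diffusers transformers accelerate peft safetensors opencv-python-headless pillow numpy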