File size: 5,338 Bytes
74853ff
c0a8ff9
333ce15
 
 
 
 
911d75e
74853ff
333ce15
e60d8ba
 
74853ff
333ce15
55fde07
 
333ce15
55fde07
333ce15
 
 
 
e60d8ba
333ce15
c0a8ff9
 
 
911d75e
333ce15
 
 
 
e60d8ba
c0a8ff9
caee6ce
 
88447c1
4c346fd
 
 
461fe92
7b58e8b
 
461fe92
616ddf9
4c346fd
 
 
7b58e8b
f6dc99f
461fe92
 
f6dc99f
4c346fd
 
 
7b58e8b
f6dc99f
461fe92
208cb2b
 
72b4d00
 
 
7b58e8b
77a0a88
208cb2b
461fe92
616ddf9
35b4284
f6dc99f
461fe92
caee6ce
333ce15
 
 
 
f6e6e5d
 
 
 
 
333ce15
 
 
d39d36a
869d316
 
4cc1869
0abaf64
 
 
 
 
 
 
 
 
 
 
 
 
 
333ce15
 
 
 
 
 
 
 
 
 
 
 
 
9f93059
333ce15
c0a8ff9
461fe92
 
 
 
208cb2b
461fe92
333ce15
 
 
 
 
 
 
461fe92
333ce15
 
89c6fa9
333ce15
 
 
461fe92
 
 
 
 
 
 
 
208cb2b
 
 
 
461fe92
 
 
 
333ce15
82699fa
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import gradio as gr
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline, AutoencoderKL, EulerAncestralDiscreteScheduler
import torch
import numpy as np
import cv2
from PIL import Image
import spaces


# Device/precision setup. NOTE(review): the original comment claimed
# "auto-detect", but "cuda" is hard-coded here and the weights are fp16 —
# this will only run correctly on GPU hardware (e.g. a ZeroGPU Space).
device = "cuda"
precision = torch.float16

# 🏗️ Load ControlNet model for Canny edge detection.
# Alternative checkpoints considered:
# xinsir/controlnet-canny-sdxl-1.0
# diffusers/controlnet-canny-sdxl-1.0
controlnet = ControlNetModel.from_pretrained(
    "xinsir/controlnet-canny-sdxl-1.0",
    torch_dtype=precision
)

# fp16-safe VAE; when testing with another base model, change the VAE too.
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=precision)

# Scheduler taken from the SDXL base repo's "scheduler" subfolder.
eulera_scheduler = EulerAncestralDiscreteScheduler.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", subfolder="scheduler")

# Assemble the full SDXL + ControlNet text-to-image pipeline.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    vae=vae,
    torch_dtype=precision,
    scheduler=eulera_scheduler,
)

# Load LoRAs under named adapters so they can be switched at runtime
# (the adapter name also works as a trigger word in the prompt).
pipe.load_lora_weights("ostris/ikea-instructions-lora-sdxl", weight_name="ikea_instructions_xl_v1_5.safetensors", adapter_name="ikea")
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
pipe.load_lora_weights('e-n-v-y/envy-junkworld-xl-01', weight_name='EnvyJunkworldXL01.safetensors', adapter_name="junkworld")

# Start with all LoRA layers disabled; the UI buttons toggle them.
pipe.disable_lora()

def activate_ikea_lora():
    """Make the IKEA-instructions adapter the pipeline's active LoRA.

    Fixes the original retry loop: ``pipe.get_active_adapters()`` can return
    an empty list, in which case indexing ``[0]`` raised IndexError, and the
    ``while`` could spin needlessly. A single ``set_adapters`` call replaces
    whatever adapters were active, so no loop (or prior disable) is needed.
    """
    print("Activating IKEA LoRa")
    pipe.set_adapters("ikea")  # replaces any currently active adapter(s)
    pipe.enable_lora()
    print("IKEA LoRa active!")

def activate_pixel_lora():
    """Make the pixel-art adapter the pipeline's active LoRA.

    Fixes the original retry loop: ``pipe.get_active_adapters()`` can return
    an empty list, in which case indexing ``[0]`` raised IndexError, and the
    ``while`` could spin needlessly. A single ``set_adapters`` call replaces
    whatever adapters were active, so no loop (or prior disable) is needed.
    """
    print("Activating PixelArt LoRa")
    pipe.set_adapters("pixel")  # replaces any currently active adapter(s)
    pipe.enable_lora()
    print("PixelArt LoRa active!")

def activate_junkworld_lora():
    """Make the junkworld adapter the pipeline's active LoRA.

    Fixes the original retry loop: ``pipe.get_active_adapters()`` can return
    an empty list, in which case indexing ``[0]`` raised IndexError, and the
    ``while`` could spin needlessly. A single ``set_adapters`` call replaces
    whatever adapters were active, so no loop (or prior disable) is needed.
    """
    print("Activating JunkWorld LoRa")
    pipe.set_adapters("junkworld")  # replaces any currently active adapter(s)
    pipe.enable_lora()
    print("JunkWorld LoRa active!")

def disable_loras():
    """Turn off every loaded LoRA adapter on the shared pipeline."""
    print("Deactivating LoRas")
    # disable_lora() switches off all LoRA layers without unloading weights,
    # so the adapters can be re-enabled later by the activate_* handlers.
    pipe.disable_lora()
    print("All LoRas deactivated!")

pipe.to(device)

# 📸 Edge detection function using OpenCV (Canny)
@spaces.GPU
def apply_canny(image, low_threshold, high_threshold):
    """Run Canny edge detection and return the edge map as a 3-channel PIL image."""
    arr = np.array(image)
    edges = cv2.Canny(arr, low_threshold, high_threshold)
    # Replicate the single-channel edge map across three channels so the
    # result is a regular RGB image the ControlNet pipeline can consume.
    rgb_edges = np.stack([edges, edges, edges], axis=-1)
    return Image.fromarray(rgb_edges)

# 🎨 Image generation function
@spaces.GPU
def generate_image(prompt, input_image, low_threshold, high_threshold, strength, guidance, controlnet_conditioning_scale):
    """Preprocess the input with Canny, then run the SDXL ControlNet pipeline.

    Args:
        prompt: Style prompt; include a LoRA adapter's trigger word to use it.
        input_image: PIL image uploaded through the Gradio UI.
        low_threshold / high_threshold: Canny hysteresis thresholds.
        strength: Forwarded to the pipeline. NOTE(review): `strength` is an
            img2img parameter; the text-to-image
            StableDiffusionXLControlNetPipeline may not accept it — verify
            against the installed diffusers version.
        guidance: Classifier-free guidance scale.
        controlnet_conditioning_scale: How strongly the edge map constrains
            generation; cast to float below since the slider may emit an int.

    Returns:
        (edge_detected, result): the Canny control image and the generated image.
    """

    # Debug aid: shows which LoRA adapters are currently active.
    print(pipe.get_active_adapters())

    # Apply edge detection
    edge_detected = apply_canny(input_image, low_threshold, high_threshold)
    
    # Generate styled image using ControlNet
    result = pipe(
        prompt=prompt,
        image=edge_detected,
        num_inference_steps=30,
        guidance_scale=guidance,
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        strength=strength
    ).images[0]
    
    return edge_detected, result

# 🖥️ Gradio UI — statement order inside the `with` blocks defines the layout.
with gr.Blocks() as demo:
    gr.Markdown("# 🏗️ 3D Screenshot to Styled Render with ControlNet")

    with gr.Row():
        # Left column: inputs and controls.
        with gr.Column():
            input_image = gr.Image(label="Upload 3D Screenshot", type="pil")
            prompt = gr.Textbox(label="Style Prompt", placeholder="e.g., Futuristic building in sunset")
            
            # Canny hysteresis thresholds fed to apply_canny.
            low_threshold = gr.Slider(50, 150, value=100, label="Canny Edge Low Threshold")
            high_threshold = gr.Slider(100, 200, value=150, label="Canny Edge High Threshold")
            
            strength = gr.Slider(0.1, 1.0, value=0.7, label="Denoising Strength")
            guidance = gr.Slider(1, 20, value=7.5, label="Guidance Scale (Creativity)")
            controlnet_conditioning_scale = gr.Slider(0, 1, value=0.5, step=0.01, label="ControlNet Conditioning Scale")

            # LoRA switcher buttons; each toggles global pipeline state.
            with gr.Row():
                ikea_lora_button = gr.Button("IKEA Instructions")
                pixel_lora_button = gr.Button("Pixel Art")
                junkworld_lora_button = gr.Button("Junk World")
                disable_lora_button = gr.Button("Disable LoRas")
            
            generate_button = gr.Button("Generate Styled Image")

        # Right column: outputs.
        with gr.Column():
            edge_output = gr.Image(label="Edge Detected Image")
            result_output = gr.Image(label="Generated Styled Image")

    # 🔗 Generate Button Action
    generate_button.click(
        fn=generate_image,
        inputs=[prompt, input_image, low_threshold, high_threshold, strength, guidance, controlnet_conditioning_scale],
        outputs=[edge_output, result_output]
    )

    # The LoRA handlers mutate the shared pipeline only, so the click
    # bindings need no Gradio inputs or outputs.
    ikea_lora_button.click(
        fn = activate_ikea_lora,
    )

    pixel_lora_button.click(
        fn = activate_pixel_lora,
    )

    junkworld_lora_button.click(
        fn = activate_junkworld_lora,
    )

    disable_lora_button.click(
        fn = disable_loras,
    )

# 🚀 Launch the app
demo.launch()